diff --git a/src/compiler/nir/nir_opt_if.c b/src/compiler/nir/nir_opt_if.c index b6c69bfd163..10336069972 100644 --- a/src/compiler/nir/nir_opt_if.c +++ b/src/compiler/nir/nir_opt_if.c @@ -1200,6 +1200,111 @@ opt_if_evaluate_condition_use(nir_builder *b, nir_if *nif) return progress; } +static bool +rewrite_comp_uses_within_if(nir_builder *b, nir_if *nif, bool invert, + nir_ssa_scalar scalar, nir_ssa_scalar new_scalar) +{ + bool progress = false; + + nir_block *first = invert ? nir_if_first_else_block(nif) : nir_if_first_then_block(nif); + nir_block *last = invert ? nir_if_last_else_block(nif) : nir_if_last_then_block(nif); + + nir_ssa_def *new_ssa = NULL; + nir_foreach_use_safe(use, scalar.def) { + if (use->parent_instr->block->index < first->index || + use->parent_instr->block->index > last->index) + continue; + + /* Only rewrite users which use only the new component. This is to avoid a + * situation where copy propagation will undo the rewrite and we risk an infinite + * loop. + * + * We could rewrite users which use a mix of the old and new components, but if + * nir_src_components_read() is incomplete, then we risk the new component actually being + * unused and some optimization later undoing the rewrite. + */ + if (nir_src_components_read(use) != BITFIELD64_BIT(scalar.comp)) + continue; + + if (!new_ssa) { + b->cursor = nir_before_cf_node(&nif->cf_node); + new_ssa = nir_channel(b, new_scalar.def, new_scalar.comp); + if (scalar.def->num_components > 1) { + nir_ssa_def *vec = nir_ssa_undef(b, scalar.def->num_components, scalar.def->bit_size); + new_ssa = nir_vector_insert_imm(b, vec, new_ssa, scalar.comp); + } + } + + nir_instr_rewrite_src_ssa(use->parent_instr, use, new_ssa); + progress = true; + } + + return progress; +} + +/* + * This optimization turns: + * + * if (a == (b=readfirstlane(a))) + * use(a) + * if (c == (d=load_const)) + * use(c) + * + * into: + * + * if (a == (b=readfirstlane(a))) + * use(b) + * if (c == (d=load_const)) + * use(d) +*/ +static bool +opt_if_rewrite_uniform_uses(nir_builder *b, nir_if *nif, nir_ssa_scalar cond, bool accept_ine) +{ + bool progress = false; + + if (!nir_ssa_scalar_is_alu(cond)) + return false; + + nir_op op = nir_ssa_scalar_alu_op(cond); + if (op == nir_op_iand) { + progress |= opt_if_rewrite_uniform_uses(b, nif, nir_ssa_scalar_chase_alu_src(cond, 0), false); + progress |= opt_if_rewrite_uniform_uses(b, nif, nir_ssa_scalar_chase_alu_src(cond, 1), false); + return progress; + } + + if (op != nir_op_ieq && (op != nir_op_ine || !accept_ine)) + return false; + + for (unsigned i = 0; i < 2; i++) { + nir_ssa_scalar src_uni = nir_ssa_scalar_chase_alu_src(cond, i); + nir_ssa_scalar src_div = nir_ssa_scalar_chase_alu_src(cond, !i); + + if (src_uni.def->parent_instr->type == nir_instr_type_load_const && src_div.def != src_uni.def) + return rewrite_comp_uses_within_if(b, nif, op == nir_op_ine, src_div, src_uni); + + if (src_uni.def->parent_instr->type != nir_instr_type_intrinsic) + continue; + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(src_uni.def->parent_instr); + if (intrin->intrinsic != nir_intrinsic_read_first_invocation && + (intrin->intrinsic != nir_intrinsic_reduce || nir_intrinsic_cluster_size(intrin))) + continue; + + nir_ssa_scalar intrin_src = {intrin->src[0].ssa, src_uni.comp}; + nir_ssa_scalar resolved_intrin_src = nir_ssa_scalar_resolved(intrin_src.def, intrin_src.comp); + + if (resolved_intrin_src.comp != src_div.comp || resolved_intrin_src.def != src_div.def) + continue; + + progress |= rewrite_comp_uses_within_if(b, nif, op == nir_op_ine, resolved_intrin_src, src_uni); + if (intrin_src.comp != resolved_intrin_src.comp || intrin_src.def != resolved_intrin_src.def) + progress |= rewrite_comp_uses_within_if(b, nif, op == nir_op_ine, intrin_src, src_uni); + + return progress; + } + + return false; +} + static void simple_merge_if(nir_if *dest_if, nir_if *src_if, bool dest_if_then, bool src_if_then) @@ -1387,6 +1492,8 @@ opt_if_safe_cf_list(nir_builder *b, struct exec_list *cf_list) progress |= opt_if_safe_cf_list(b, &nif->then_list); progress |= opt_if_safe_cf_list(b, &nif->else_list); progress |= opt_if_evaluate_condition_use(b, nif); + nir_ssa_scalar cond = nir_ssa_scalar_resolved(nif->condition.ssa, 0); + progress |= opt_if_rewrite_uniform_uses(b, nif, cond, true); break; }