diff --git a/src/compiler/nir/nir_opt_if.c b/src/compiler/nir/nir_opt_if.c index 0fccebf95e6..65904175c6a 100644 --- a/src/compiler/nir/nir_opt_if.c +++ b/src/compiler/nir/nir_opt_if.c @@ -934,6 +934,60 @@ opt_if_simplification(nir_builder *b, nir_if *nif) return true; } +/* Find phi statements after an if that choose between true and false, and + * replace them with the if statement's condition (or an inot of it). + */ +static bool +opt_if_phi_is_condition(nir_builder *b, nir_if *nif) +{ + /* Grab pointers to the last then/else blocks for looking in the phis. */ + nir_block *then_block = nir_if_last_then_block(nif); + nir_block *else_block = nir_if_last_else_block(nif); + nir_ssa_def *cond = nif->condition.ssa; + bool progress = false; + + nir_block *after_if_block = nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node)); + nir_foreach_instr_safe(instr, after_if_block) { + if (instr->type != nir_instr_type_phi) + break; + + nir_phi_instr *phi = nir_instr_as_phi(instr); + if (phi->dest.ssa.bit_size != cond->bit_size || + phi->dest.ssa.num_components != 1) + continue; + + enum opt_bool { + T, F, UNKNOWN + } then_val = UNKNOWN, else_val = UNKNOWN; + + nir_foreach_phi_src(src, phi) { + assert(src->pred == then_block || src->pred == else_block); + enum opt_bool *pred_val = src->pred == then_block ? &then_val : &else_val; + + nir_ssa_scalar val = nir_ssa_scalar_resolved(src->src.ssa, 0); + if (!nir_ssa_scalar_is_const(val)) + break; + + if (nir_ssa_scalar_as_int(val) == -1) + *pred_val = T; + else if (nir_ssa_scalar_as_uint(val) == 0) + *pred_val = F; + else + break; + } + if (then_val == T && else_val == F) { + nir_ssa_def_rewrite_uses(&phi->dest.ssa, cond); + progress = true; + } else if (then_val == F && else_val == T) { + b->cursor = nir_before_cf_node(&nif->cf_node); + nir_ssa_def_rewrite_uses(&phi->dest.ssa, nir_inot(b, cond)); + progress = true; + } + } + + return progress; +} + /** * This optimization tries to merge two break statements into a single break. * For this purpose, it checks if both branch legs end in a break or @@ -1560,6 +1614,7 @@ opt_if_cf_list(nir_builder *b, struct exec_list *cf_list, progress |= opt_if_loop_terminator(nif); progress |= opt_if_merge(nif); progress |= opt_if_simplification(b, nif); + progress |= opt_if_phi_is_condition(b, nif); break; }