From 0e5e1cb9b02a9057e98be6b91c3e528ee08e0970 Mon Sep 17 00:00:00 2001
From: Georg Lehmann
Date: Tue, 16 Dec 2025 15:30:49 +0100
Subject: [PATCH] nir/opt_uniform_subgroup: optimize min/max/and/or reduce of bcsel(div, con, con)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Foz-DB Navi48:
Totals from 1 (0.00% of 97397) affected shaders:
Instrs: 1848 -> 1834 (-0.76%)
CodeSize: 9996 -> 9908 (-0.88%)
VGPRs: 96 -> 72 (-25.00%)
Latency: 17371 -> 17358 (-0.07%)
Copies: 190 -> 191 (+0.53%)
PreVGPRs: 43 -> 41 (-4.65%)
VALU: 657 -> 648 (-1.37%)

Reviewed-by: Daniel Schürmann
Part-of:
---
 src/compiler/nir/nir_opt_uniform_subgroup.c | 109 +++++++++++++++++++-
 1 file changed, 105 insertions(+), 4 deletions(-)

diff --git a/src/compiler/nir/nir_opt_uniform_subgroup.c b/src/compiler/nir/nir_opt_uniform_subgroup.c
index f29bc54aaa5..b4b39b1e6fb 100644
--- a/src/compiler/nir/nir_opt_uniform_subgroup.c
+++ b/src/compiler/nir/nir_opt_uniform_subgroup.c
@@ -11,6 +11,89 @@
 #include "nir/nir.h"
 #include "nir/nir_builder.h"
 
+struct select_info {
+   nir_def *cond; /* select condition */
+   nir_def *values[2]; /* [0]: value when cond is true, [1]: value when cond is false */
+};
+/* Describe def as cond ? values[0] : values[1]; returns false (info untouched) when it cannot. */
+static bool
+parse_select_of_con_values(nir_builder *b, nir_def *def, struct select_info *info)
+{
+   if (!nir_def_is_alu(def))
+      return false;
+
+   nir_alu_instr *alu = nir_def_as_alu(def);
+   unsigned bit_size = def->bit_size;
+   unsigned num_components = def->num_components;
+   nir_block *use_block = nir_cursor_current_block(b->cursor);
+
+   switch (alu->op) {
+   case nir_op_b2f16:
+   case nir_op_b2f32:
+   case nir_op_b2f64: { /* b2f(cond): selects between 1.0 and 0.0 */
+      info->cond = nir_mov_alu(b, alu->src[0], num_components);
+      for (unsigned i = 0; i < 2; i++)
+         info->values[i] = nir_imm_floatN_t(b, i ? 0.0 : 1.0, bit_size);
+
+      return true;
+   }
+   case nir_op_b2i8:
+   case nir_op_b2i16:
+   case nir_op_b2i32:
+   case nir_op_b2i64: { /* b2i(cond): selects between 1 and 0 */
+      info->cond = nir_mov_alu(b, alu->src[0], num_components);
+      for (unsigned i = 0; i < 2; i++)
+         info->values[i] = nir_imm_intN_t(b, i ? 0 : 1, bit_size);
+
+      return true;
+   }
+   case nir_op_fneg:
+   case nir_op_ineg: {
+      /* nir_opt_algebraic canonicalizes a ? -1 : 0 to neg(b2f/b2i(a)),
+       * so look for this here.
+       */
+      nir_alu_instr *b2t = nir_def_as_alu_or_null(alu->src[0].src.ssa);
+
+      bool is_float = alu->op == nir_op_fneg;
+      nir_alu_type dest_type = (is_float ? nir_type_float : nir_type_uint) | bit_size;
+
+      nir_op b2t_op = nir_type_conversion_op(nir_type_bool1, dest_type, nir_rounding_mode_undef);
+
+      if (!b2t || b2t->op != b2t_op)
+         return false;
+
+      nir_alu_src neg_src = { NIR_SRC_INIT };
+      neg_src.src = nir_src_for_ssa(nir_mov_alu(b, b2t->src[0], b2t->def.num_components));
+      memcpy(neg_src.swizzle, alu->src[0].swizzle, sizeof(alu->src[0].swizzle)); /* keep the swizzle the neg applied to its source */
+
+      info->cond = nir_mov_alu(b, neg_src, num_components);
+
+      for (unsigned i = 0; i < 2; i++) {
+         if (is_float)
+            info->values[i] = nir_imm_floatN_t(b, i ? -0.0 : -1.0, bit_size);
+         else
+            info->values[i] = nir_imm_intN_t(b, i ? 0 : -1, bit_size);
+      }
+
+      return true;
+   }
+   case nir_op_bcsel: { /* accepted only if neither selected value is divergent at the use block */
+      for (unsigned i = 0; i < 2; i++) {
+         if (nir_def_is_divergent_at_use_block(alu->src[1 + i].src.ssa, use_block))
+            return false;
+      }
+
+      info->cond = nir_mov_alu(b, alu->src[0], num_components);
+      for (unsigned i = 0; i < 2; i++)
+         info->values[i] = nir_mov_alu(b, alu->src[1 + i], num_components);
+
+      return true;
+   }
+   default:
+      return false;
+   }
+}
+
 static nir_def *
 ballot_bit_count(nir_builder *b, nir_def *ballot)
 {
@@ -99,15 +182,14 @@ opt_uniform_subgroup_instr(nir_builder *b, nir_intrinsic_instr *intrin, void *_s
    case nir_intrinsic_reduce:
    case nir_intrinsic_exclusive_scan:
    case nir_intrinsic_inclusive_scan: {
-      if (nir_src_is_divergent(&intrin->src[0]))
-         return false;
-
       const nir_op reduction_op = (nir_op)nir_intrinsic_reduction_op(intrin);
 
       switch (reduction_op) {
       case nir_op_iadd:
       case nir_op_fadd:
       case nir_op_ixor: {
+         if (nir_src_is_divergent(&intrin->src[0]))
+            return false;
          if (nir_intrinsic_has_cluster_size(intrin) && nir_intrinsic_cluster_size(intrin))
             return false;
          nir_def *count;
@@ -151,7 +233,26 @@ opt_uniform_subgroup_instr(nir_builder *b, nir_intrinsic_instr *intrin, void *_s
       case nir_op_ior:
          if (intrin->intrinsic == nir_intrinsic_exclusive_scan)
             return false;
-         replacement = intrin->src[0].ssa;
+         if (!nir_src_is_divergent(&intrin->src[0])) {
+            replacement = intrin->src[0].ssa;
+         } else { /* divergent source: only a scalar reduce with zero cluster size is handled */
+            if (intrin->intrinsic != nir_intrinsic_reduce)
+               return false;
+            if (nir_intrinsic_cluster_size(intrin))
+               return false;
+            if (intrin->def.num_components != 1)
+               return false;
+
+            struct select_info sel;
+            if (!parse_select_of_con_values(b, intrin->src[0].ssa, &sel))
+               return false;
+
+            nir_def *mix_value = nir_build_alu2(b, reduction_op, sel.values[0], sel.values[1]);
+            /* vote_all(cond): values[0]; !vote_any(cond): values[1]; otherwise: mix_value. */
+            replacement = nir_bcsel(b, nir_vote_any(b, 1, sel.cond),
+                                    nir_bcsel(b, nir_vote_all(b, 1, sel.cond), sel.values[0], mix_value),
+                                    sel.values[1]);
+         }
          break;
 
       default: