nir/opt_uniform_subgroup: optimize min/max/and/or reduce of bcsel(div, con, con)

Foz-DB Navi48:
Totals from 1 (0.00% of 97397) affected shaders:
Instrs: 1848 -> 1834 (-0.76%)
CodeSize: 9996 -> 9908 (-0.88%)
VGPRs: 96 -> 72 (-25.00%)
Latency: 17371 -> 17358 (-0.07%)
Copies: 190 -> 191 (+0.53%)
PreVGPRs: 43 -> 41 (-4.65%)
VALU: 657 -> 648 (-1.37%)

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38974>
This commit is contained in:
Georg Lehmann 2025-12-16 15:30:49 +01:00 committed by Marge Bot
parent 4d8cc7d82e
commit 0e5e1cb9b0

View file

@ -11,6 +11,89 @@
#include "nir/nir.h"
#include "nir/nir_builder.h"
struct select_info {
nir_def *cond;
nir_def *values[2];
};
static bool
parse_select_of_con_values(nir_builder *b, nir_def *def, struct select_info *info)
{
if (!nir_def_is_alu(def))
return false;
nir_alu_instr *alu = nir_def_as_alu(def);
unsigned bit_size = def->bit_size;
unsigned num_components = def->num_components;
nir_block *use_block = nir_cursor_current_block(b->cursor);
switch (alu->op) {
case nir_op_b2f16:
case nir_op_b2f32:
case nir_op_b2f64: {
info->cond = nir_mov_alu(b, alu->src[0], num_components);
for (unsigned i = 0; i < 2; i++)
info->values[i] = nir_imm_floatN_t(b, i ? 0.0 : 1.0, bit_size);
return true;
}
case nir_op_b2i8:
case nir_op_b2i16:
case nir_op_b2i32:
case nir_op_b2i64: {
info->cond = nir_mov_alu(b, alu->src[0], num_components);
for (unsigned i = 0; i < 2; i++)
info->values[i] = nir_imm_intN_t(b, i ? 0 : 1, bit_size);
return true;
}
case nir_op_fneg:
case nir_op_ineg: {
/* nir_opt_algebraic canonicalizes a ? -1 : 0 to neg(b2f/b2i(a)),
* so look for this here.
*/
nir_alu_instr *b2t = nir_def_as_alu_or_null(alu->src[0].src.ssa);
bool is_float = alu->op == nir_op_fneg;
nir_alu_type dest_type = (is_float ? nir_type_float : nir_type_uint) | bit_size;
nir_op b2t_op = nir_type_conversion_op(nir_type_bool1, dest_type, nir_rounding_mode_undef);
if (!b2t || b2t->op != b2t_op)
return false;
nir_alu_src neg_src = { NIR_SRC_INIT };
neg_src.src = nir_src_for_ssa(nir_mov_alu(b, b2t->src[0], b2t->def.num_components));
memcpy(neg_src.swizzle, alu->src[0].swizzle, sizeof(alu->src[0].swizzle));
info->cond = nir_mov_alu(b, neg_src, num_components);
for (unsigned i = 0; i < 2; i++) {
if (is_float)
info->values[i] = nir_imm_floatN_t(b, i ? -0.0 : -1.0, bit_size);
else
info->values[i] = nir_imm_intN_t(b, i ? 0 : -1, bit_size);
}
return true;
}
case nir_op_bcsel: {
for (unsigned i = 0; i < 2; i++) {
if (nir_def_is_divergent_at_use_block(alu->src[1 + i].src.ssa, use_block))
return false;
}
info->cond = nir_mov_alu(b, alu->src[0], num_components);
for (unsigned i = 0; i < 2; i++)
info->values[i] = nir_mov_alu(b, alu->src[1 + i], num_components);
return true;
}
default:
return false;
}
}
static nir_def *
ballot_bit_count(nir_builder *b, nir_def *ballot)
{
@ -99,15 +182,14 @@ opt_uniform_subgroup_instr(nir_builder *b, nir_intrinsic_instr *intrin, void *_s
case nir_intrinsic_reduce:
case nir_intrinsic_exclusive_scan:
case nir_intrinsic_inclusive_scan: {
if (nir_src_is_divergent(&intrin->src[0]))
return false;
const nir_op reduction_op = (nir_op)nir_intrinsic_reduction_op(intrin);
switch (reduction_op) {
case nir_op_iadd:
case nir_op_fadd:
case nir_op_ixor: {
if (nir_src_is_divergent(&intrin->src[0]))
return false;
if (nir_intrinsic_has_cluster_size(intrin) && nir_intrinsic_cluster_size(intrin))
return false;
nir_def *count;
@ -151,7 +233,26 @@ opt_uniform_subgroup_instr(nir_builder *b, nir_intrinsic_instr *intrin, void *_s
case nir_op_ior:
if (intrin->intrinsic == nir_intrinsic_exclusive_scan)
return false;
replacement = intrin->src[0].ssa;
if (!nir_src_is_divergent(&intrin->src[0])) {
replacement = intrin->src[0].ssa;
} else {
if (intrin->intrinsic != nir_intrinsic_reduce)
return false;
if (nir_intrinsic_cluster_size(intrin))
return false;
if (intrin->def.num_components != 1)
return false;
struct select_info sel;
if (!parse_select_of_con_values(b, intrin->src[0].ssa, &sel))
return false;
nir_def *mix_value = nir_build_alu2(b, reduction_op, sel.values[0], sel.values[1]);
replacement = nir_bcsel(b, nir_vote_any(b, 1, sel.cond),
nir_bcsel(b, nir_vote_all(b, 1, sel.cond), sel.values[0], mix_value),
sel.values[1]);
}
break;
default: