From 9bc14a0047eae19ad89df7512c490767c0747e7f Mon Sep 17 00:00:00 2001
From: Georg Lehmann
Date: Wed, 3 Sep 2025 15:27:46 +0200
Subject: [PATCH] nir/lower_subgroups: optimize reduce/scans with unknown
 subgroup size

We skip iterations with ifs on the subgroup size. These ifs can be
optimized away once the subgroup size is known; every driver should do
that anyway, because applications depend on it.

Reviewed-by: Alyssa Rosenzweig
Reviewed-by: Emma Anholt
Part-of:
---
 src/compiler/nir/nir_lower_subgroups.c | 43 ++++++++++++++++++++++++--
 1 file changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index e46bf0d00fa..1f2ac17dda0 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -701,18 +701,32 @@ build_identity(nir_builder *b, unsigned bit_size, nir_op op)
 /* Implementation of scan/reduce that assumes a full subgroup */
 static nir_def *
 build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
-                nir_def *data, unsigned cluster_size)
+                nir_def *data, unsigned cluster_size,
+                const nir_lower_subgroups_options *options)
 {
+   bool unknown_size = !options->subgroup_size;
+   nir_def *subgroup_size = unknown_size ? nir_load_subgroup_size(b) : NULL;
+
    switch (op) {
    case nir_intrinsic_exclusive_scan:
    case nir_intrinsic_inclusive_scan: {
       for (unsigned i = 1; i < cluster_size; i *= 2) {
+         nir_def *old_data = data;
+
+         if (unknown_size)
+            nir_push_if(b, nir_ugt_imm(b, subgroup_size, i));
+
          nir_def *idx = nir_load_subgroup_invocation(b);
          nir_def *has_buddy = nir_ige_imm(b, idx, i);
 
          nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
          nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
          data = nir_bcsel(b, has_buddy, accum, data);
+
+         if (unknown_size) {
+            nir_pop_if(b, NULL);
+            data = nir_if_phi(b, data, old_data);
+         }
       }
 
       if (op == nir_intrinsic_exclusive_scan) {
@@ -732,8 +746,18 @@ build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
 
    case nir_intrinsic_reduce: {
       for (unsigned i = 1; i < cluster_size; i *= 2) {
+         nir_def *old_data = data;
+
+         if (unknown_size)
+            nir_push_if(b, nir_ugt_imm(b, subgroup_size, i));
+
          nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
          data = nir_build_alu2(b, red_op, data, buddy_data);
+
+         if (unknown_size) {
+            nir_pop_if(b, NULL);
+            data = nir_if_phi(b, data, old_data);
+         }
       }
       return data;
    }
@@ -749,6 +773,9 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                   nir_def *data, nir_def *mask, unsigned max_mask_bits,
                   const nir_lower_subgroups_options *options)
 {
+   bool unknown_size = !options->subgroup_size;
+   nir_def *subgroup_size = unknown_size ? nir_load_subgroup_size(b) : NULL;
+
    nir_def *lt_mask = nir_load_subgroup_lt_mask(b, options->ballot_components,
                                                 options->ballot_bit_size);
 
@@ -759,6 +786,12 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
    nir_def *remaining = nir_iand(b, mask, lt_mask);
 
    for (unsigned i = 1; i < max_mask_bits; i *= 2) {
+      nir_def *old_data = data;
+      nir_def *old_remaining = remaining;
+
+      if (unknown_size)
+         nir_push_if(b, nir_ugt_imm(b, subgroup_size, i));
+
       /* At each step, our buddy channel is the first channel we have yet to
        * take into account in the accumulator.
       */
@@ -776,6 +809,12 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
        */
       nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
       remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
+
+      if (unknown_size) {
+         nir_pop_if(b, NULL);
+         data = nir_if_phi(b, data, old_data);
+         remaining = nir_if_phi(b, remaining, old_remaining);
+      }
    }
 
    switch (op) {
@@ -868,7 +907,7 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
    nir_push_if(b, nir_ball_iequal(b, mask, build_subgroup_mask(b, options)));
    {
       full = build_scan_full(b, intrin->intrinsic, red_op,
-                             intrin->src[0].ssa, cluster_size);
+                             intrin->src[0].ssa, cluster_size, options);
    }
    nir_push_else(b, NULL);
    {
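
A note for readers new to this pass: below is a rough, host-side C
sketch of the control flow build_scan_full now emits for an inclusive
scan when the subgroup size is unknown. It is an illustration under
assumed names, not Mesa code: CLUSTER_SIZE, inclusive_scan_add() and
the explicit per-lane loops are hypothetical stand-ins for the
per-invocation NIR. The point is the "subgroup_size > i" guard on every
doubling step, matching the nir_push_if() added by this patch; once a
driver knows the subgroup size, each guard folds to a constant and the
dead iterations disappear.

#include <stdint.h>
#include <stdio.h>

/* Hypothetical worst-case cluster size the pass unrolls for. */
#define CLUSTER_SIZE 64

/* Hillis-Steele inclusive add-scan across lanes, with every doubling
 * step guarded by subgroup_size > i, mirroring nir_push_if() above. */
static void
inclusive_scan_add(uint32_t *data, unsigned subgroup_size)
{
   for (unsigned i = 1; i < CLUSTER_SIZE; i *= 2) {
      /* The guard this patch emits: with a known subgroup size this
       * condition becomes a compile-time constant. */
      if (!(subgroup_size > i))
         continue;

      uint32_t next[CLUSTER_SIZE];
      for (unsigned lane = 0; lane < subgroup_size; lane++) {
         /* shuffle_up(data, i) combined with bcsel(lane >= i, ...) */
         next[lane] = lane >= i ? data[lane] + data[lane - i] : data[lane];
      }
      for (unsigned lane = 0; lane < subgroup_size; lane++)
         data[lane] = next[lane];
   }
}

int
main(void)
{
   unsigned subgroup_size = 32; /* e.g. a wave32 GPU */
   uint32_t data[CLUSTER_SIZE];

   for (unsigned lane = 0; lane < subgroup_size; lane++)
      data[lane] = 1;

   inclusive_scan_add(data, subgroup_size);

   /* Prints 1 2 3 ... 32: lane n holds the inclusive prefix sum. */
   for (unsigned lane = 0; lane < subgroup_size; lane++)
      printf("%u ", data[lane]);
   printf("\n");
   return 0;
}

With subgroup_size specialized to 32, the i == 32 step above is
statically dead, which is exactly the if-folding the commit message
expects every driver to perform.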