From 509606e56d88ebd7c015ae0d92422536d4c52660 Mon Sep 17 00:00:00 2001 From: Job Noorman Date: Thu, 17 Oct 2024 21:44:39 +0200 Subject: [PATCH] nir/lower_subgroups: scan/reduce for multiple ballot components lower_scan_reduce only worked when ballot_components equals one. This commit adds support for arbitrary ballot_components. Signed-off-by: Job Noorman Reviewed-by: Alyssa Rosenzweig Reviewed-by: Faith Ekstrand Part-of: --- src/compiler/nir/nir_lower_subgroups.c | 63 +++++++++++++++++++------- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index 263191d8675..20b486c2956 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -742,9 +742,10 @@ build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op, static nir_def * build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op, nir_def *data, nir_def *mask, unsigned max_mask_bits, - unsigned subgroup_size) + const nir_lower_subgroups_options *options) { - nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, subgroup_size); + nir_def *lt_mask = nir_load_subgroup_lt_mask(b, options->ballot_components, + options->ballot_bit_size); /* Mask of all channels whose values we need to accumulate. Our own value * is already in accum, if inclusive, thanks to the initialization above. @@ -756,8 +757,8 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op, /* At each step, our buddy channel is the first channel we have yet to * take into account in the accumulator. */ - nir_def *has_buddy = nir_ine_imm(b, remaining, 0); - nir_def *buddy = nir_ufind_msb(b, remaining); + nir_def *has_buddy = nir_bany_inequal(b, remaining, nir_imm_int(b, 0)); + nir_def *buddy = nir_ballot_find_msb(b, 32, remaining); /* Accumulate with our buddy channel, if any */ nir_def *buddy_data = nir_shuffle(b, data, buddy); @@ -781,8 +782,8 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op, * code is cleaner this way. */ nir_def *lower = nir_iand(b, mask, lt_mask); - nir_def *has_buddy = nir_ine_imm(b, lower, 0); - nir_def *buddy = nir_ufind_msb(b, lower); + nir_def *has_buddy = nir_bany_inequal(b, lower, nir_imm_int(b, 0)); + nir_def *buddy = nir_ballot_find_msb(b, 32, lower); nir_def *buddy_data = nir_shuffle(b, data, buddy); nir_def *identity = build_identity(b, data->bit_size, red_op); @@ -794,7 +795,7 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op, case nir_intrinsic_reduce: { /* For reductions, we need to take the top value of the scan */ - nir_def *idx = nir_ufind_msb(b, mask); + nir_def *idx = nir_ballot_find_msb(b, 32, mask); return nir_shuffle(b, data, idx); } @@ -804,20 +805,47 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op, } static nir_def * -build_cluster_mask(nir_builder *b, unsigned cluster_size) +build_cluster_mask(nir_builder *b, unsigned cluster_size, + const nir_lower_subgroups_options *options) { nir_def *idx = nir_load_subgroup_invocation(b); nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1)); - nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size)); - return nir_ishl(b, cluster_mask, cluster); + if (cluster_size <= options->ballot_bit_size) { + return build_ballot_imm_ishl(b, BITFIELD_MASK(cluster_size), cluster, + options); + } + + /* Since the cluster size and the ballot bit size are both powers of 2, + * cluster size will be a multiple of the ballot bit size. Therefore, each + * ballot component will be either all ones or all zeros. Build a vec for + * which each component holds the value of `cluster` for which the mask + * should be all ones. + */ + nir_const_value cluster_sel_const[4]; + assert(ARRAY_SIZE(cluster_sel_const) >= options->ballot_components); + + for (unsigned i = 0; i < options->ballot_components; i++) { + unsigned cluster_val = + ROUND_DOWN_TO(i * options->ballot_bit_size, cluster_size); + cluster_sel_const[i] = + nir_const_value_for_uint(cluster_val, options->ballot_bit_size); + } + + nir_def *cluster_sel = + nir_build_imm(b, options->ballot_components, options->ballot_bit_size, + cluster_sel_const); + nir_def *ones = nir_imm_intN_t(b, -1, options->ballot_bit_size); + nir_def *zeros = nir_imm_intN_t(b, 0, options->ballot_bit_size); + return nir_bcsel(b, nir_ieq(b, cluster, cluster_sel), ones, zeros); } static nir_def * lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin, - unsigned subgroup_size) + const nir_lower_subgroups_options *options) { const nir_op red_op = nir_intrinsic_reduction_op(intrin); + unsigned subgroup_size = options->subgroup_size; /* Grab the cluster size */ unsigned cluster_size = subgroup_size; @@ -828,10 +856,11 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin, } /* Check if all invocations are active. If so, we use the fast path. */ - nir_def *mask = nir_ballot(b, 1, subgroup_size, nir_imm_true(b)); + nir_def *mask = nir_ballot(b, options->ballot_components, + options->ballot_bit_size, nir_imm_true(b)); nir_def *full, *partial; - nir_push_if(b, nir_ieq_imm(b, mask, -1)); + nir_push_if(b, nir_ball_iequal(b, mask, build_subgroup_mask(b, options))); { full = build_scan_full(b, intrin->intrinsic, red_op, intrin->src[0].ssa, cluster_size); @@ -840,13 +869,13 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin, { /* Mask according to the cluster size */ if (cluster_size < subgroup_size) { - nir_def *cluster_mask = build_cluster_mask(b, cluster_size); + nir_def *cluster_mask = build_cluster_mask(b, cluster_size, options); mask = nir_iand(b, mask, cluster_mask); } partial = build_scan_reduce(b, intrin->intrinsic, red_op, intrin->src[0].ssa, mask, cluster_size, - subgroup_size); + options); } nir_pop_if(b, NULL); return nir_if_phi(b, full, partial); @@ -1261,7 +1290,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) (options->lower_boolean_reduce || options->lower_reduce)) return lower_boolean_reduce(b, intrin, options); if (options->lower_reduce) - return lower_scan_reduce(b, intrin, options->subgroup_size); + return lower_scan_reduce(b, intrin, options); return ret; } case nir_intrinsic_inclusive_scan: @@ -1272,7 +1301,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) (options->lower_boolean_reduce || options->lower_reduce)) return lower_boolean_reduce(b, intrin, options); if (options->lower_reduce) - return lower_scan_reduce(b, intrin, options->subgroup_size); + return lower_scan_reduce(b, intrin, options); break; case nir_intrinsic_rotate: