nir/lower_subgroups: scan/reduce for multiple ballot components

lower_scan_reduce only worked when ballot_components equals one. This
commit adds support for arbitrary ballot_components.

Signed-off-by: Job Noorman <jnoorman@igalia.com>
Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31587>
This commit is contained in:
Job Noorman 2024-10-17 21:44:39 +02:00 committed by Marge Bot
parent 58b199f7ed
commit 509606e56d

View file

@ -742,9 +742,10 @@ build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
static nir_def *
build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
nir_def *data, nir_def *mask, unsigned max_mask_bits,
unsigned subgroup_size)
const nir_lower_subgroups_options *options)
{
nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, subgroup_size);
nir_def *lt_mask = nir_load_subgroup_lt_mask(b, options->ballot_components,
options->ballot_bit_size);
/* Mask of all channels whose values we need to accumulate. Our own value
* is already in accum, if inclusive, thanks to the initialization above.
@ -756,8 +757,8 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
/* At each step, our buddy channel is the first channel we have yet to
* take into account in the accumulator.
*/
nir_def *has_buddy = nir_ine_imm(b, remaining, 0);
nir_def *buddy = nir_ufind_msb(b, remaining);
nir_def *has_buddy = nir_bany_inequal(b, remaining, nir_imm_int(b, 0));
nir_def *buddy = nir_ballot_find_msb(b, 32, remaining);
/* Accumulate with our buddy channel, if any */
nir_def *buddy_data = nir_shuffle(b, data, buddy);
@ -781,8 +782,8 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
* code is cleaner this way.
*/
nir_def *lower = nir_iand(b, mask, lt_mask);
nir_def *has_buddy = nir_ine_imm(b, lower, 0);
nir_def *buddy = nir_ufind_msb(b, lower);
nir_def *has_buddy = nir_bany_inequal(b, lower, nir_imm_int(b, 0));
nir_def *buddy = nir_ballot_find_msb(b, 32, lower);
nir_def *buddy_data = nir_shuffle(b, data, buddy);
nir_def *identity = build_identity(b, data->bit_size, red_op);
@ -794,7 +795,7 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
case nir_intrinsic_reduce: {
/* For reductions, we need to take the top value of the scan */
nir_def *idx = nir_ufind_msb(b, mask);
nir_def *idx = nir_ballot_find_msb(b, 32, mask);
return nir_shuffle(b, data, idx);
}
@ -804,20 +805,47 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
}
static nir_def *
build_cluster_mask(nir_builder *b, unsigned cluster_size)
build_cluster_mask(nir_builder *b, unsigned cluster_size,
const nir_lower_subgroups_options *options)
{
nir_def *idx = nir_load_subgroup_invocation(b);
nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));
nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size));
return nir_ishl(b, cluster_mask, cluster);
if (cluster_size <= options->ballot_bit_size) {
return build_ballot_imm_ishl(b, BITFIELD_MASK(cluster_size), cluster,
options);
}
/* Since the cluster size and the ballot bit size are both powers of 2,
* cluster size will be a multiple of the ballot bit size. Therefore, each
* ballot component will be either all ones or all zeros. Build a vec for
* which each component holds the value of `cluster` for which the mask
* should be all ones.
*/
nir_const_value cluster_sel_const[4];
assert(ARRAY_SIZE(cluster_sel_const) >= options->ballot_components);
for (unsigned i = 0; i < options->ballot_components; i++) {
unsigned cluster_val =
ROUND_DOWN_TO(i * options->ballot_bit_size, cluster_size);
cluster_sel_const[i] =
nir_const_value_for_uint(cluster_val, options->ballot_bit_size);
}
nir_def *cluster_sel =
nir_build_imm(b, options->ballot_components, options->ballot_bit_size,
cluster_sel_const);
nir_def *ones = nir_imm_intN_t(b, -1, options->ballot_bit_size);
nir_def *zeros = nir_imm_intN_t(b, 0, options->ballot_bit_size);
return nir_bcsel(b, nir_ieq(b, cluster, cluster_sel), ones, zeros);
}
static nir_def *
lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
unsigned subgroup_size)
const nir_lower_subgroups_options *options)
{
const nir_op red_op = nir_intrinsic_reduction_op(intrin);
unsigned subgroup_size = options->subgroup_size;
/* Grab the cluster size */
unsigned cluster_size = subgroup_size;
@ -828,10 +856,11 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
}
/* Check if all invocations are active. If so, we use the fast path. */
nir_def *mask = nir_ballot(b, 1, subgroup_size, nir_imm_true(b));
nir_def *mask = nir_ballot(b, options->ballot_components,
options->ballot_bit_size, nir_imm_true(b));
nir_def *full, *partial;
nir_push_if(b, nir_ieq_imm(b, mask, -1));
nir_push_if(b, nir_ball_iequal(b, mask, build_subgroup_mask(b, options)));
{
full = build_scan_full(b, intrin->intrinsic, red_op,
intrin->src[0].ssa, cluster_size);
@ -840,13 +869,13 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
{
/* Mask according to the cluster size */
if (cluster_size < subgroup_size) {
nir_def *cluster_mask = build_cluster_mask(b, cluster_size);
nir_def *cluster_mask = build_cluster_mask(b, cluster_size, options);
mask = nir_iand(b, mask, cluster_mask);
}
partial = build_scan_reduce(b, intrin->intrinsic, red_op,
intrin->src[0].ssa, mask, cluster_size,
subgroup_size);
options);
}
nir_pop_if(b, NULL);
return nir_if_phi(b, full, partial);
@ -1261,7 +1290,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
(options->lower_boolean_reduce || options->lower_reduce))
return lower_boolean_reduce(b, intrin, options);
if (options->lower_reduce)
return lower_scan_reduce(b, intrin, options->subgroup_size);
return lower_scan_reduce(b, intrin, options);
return ret;
}
case nir_intrinsic_inclusive_scan:
@ -1272,7 +1301,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
(options->lower_boolean_reduce || options->lower_reduce))
return lower_boolean_reduce(b, intrin, options);
if (options->lower_reduce)
return lower_scan_reduce(b, intrin, options->subgroup_size);
return lower_scan_reduce(b, intrin, options);
break;
case nir_intrinsic_rotate: