mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-09 04:38:03 +02:00
nir/lower_subgroups: scan/reduce for multiple ballot components
lower_scan_reduce only worked when ballot_components equals one. This commit adds support for arbitrary ballot_components.

Signed-off-by: Job Noorman <jnoorman@igalia.com>
Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31587>
This commit is contained in:
parent
58b199f7ed
commit
509606e56d
1 changed file with 46 additions and 17 deletions
|
|
@ -742,9 +742,10 @@ build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
|
|||
static nir_def *
|
||||
build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
|
||||
nir_def *data, nir_def *mask, unsigned max_mask_bits,
|
||||
unsigned subgroup_size)
|
||||
const nir_lower_subgroups_options *options)
|
||||
{
|
||||
nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, subgroup_size);
|
||||
nir_def *lt_mask = nir_load_subgroup_lt_mask(b, options->ballot_components,
|
||||
options->ballot_bit_size);
|
||||
|
||||
/* Mask of all channels whose values we need to accumulate. Our own value
|
||||
* is already in accum, if inclusive, thanks to the initialization above.
|
||||
|
|
@ -756,8 +757,8 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
|
|||
/* At each step, our buddy channel is the first channel we have yet to
|
||||
* take into account in the accumulator.
|
||||
*/
|
||||
nir_def *has_buddy = nir_ine_imm(b, remaining, 0);
|
||||
nir_def *buddy = nir_ufind_msb(b, remaining);
|
||||
nir_def *has_buddy = nir_bany_inequal(b, remaining, nir_imm_int(b, 0));
|
||||
nir_def *buddy = nir_ballot_find_msb(b, 32, remaining);
|
||||
|
||||
/* Accumulate with our buddy channel, if any */
|
||||
nir_def *buddy_data = nir_shuffle(b, data, buddy);
|
||||
|
|
@ -781,8 +782,8 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
|
|||
* code is cleaner this way.
|
||||
*/
|
||||
nir_def *lower = nir_iand(b, mask, lt_mask);
|
||||
nir_def *has_buddy = nir_ine_imm(b, lower, 0);
|
||||
nir_def *buddy = nir_ufind_msb(b, lower);
|
||||
nir_def *has_buddy = nir_bany_inequal(b, lower, nir_imm_int(b, 0));
|
||||
nir_def *buddy = nir_ballot_find_msb(b, 32, lower);
|
||||
|
||||
nir_def *buddy_data = nir_shuffle(b, data, buddy);
|
||||
nir_def *identity = build_identity(b, data->bit_size, red_op);
|
||||
|
|
@ -794,7 +795,7 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
|
|||
|
||||
case nir_intrinsic_reduce: {
|
||||
/* For reductions, we need to take the top value of the scan */
|
||||
nir_def *idx = nir_ufind_msb(b, mask);
|
||||
nir_def *idx = nir_ballot_find_msb(b, 32, mask);
|
||||
return nir_shuffle(b, data, idx);
|
||||
}
|
||||
|
||||
|
|
@ -804,20 +805,47 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
|
|||
}
|
||||
|
||||
static nir_def *
|
||||
build_cluster_mask(nir_builder *b, unsigned cluster_size)
|
||||
build_cluster_mask(nir_builder *b, unsigned cluster_size,
|
||||
const nir_lower_subgroups_options *options)
|
||||
{
|
||||
nir_def *idx = nir_load_subgroup_invocation(b);
|
||||
nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));
|
||||
|
||||
nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size));
|
||||
return nir_ishl(b, cluster_mask, cluster);
|
||||
if (cluster_size <= options->ballot_bit_size) {
|
||||
return build_ballot_imm_ishl(b, BITFIELD_MASK(cluster_size), cluster,
|
||||
options);
|
||||
}
|
||||
|
||||
/* Since the cluster size and the ballot bit size are both powers of 2,
|
||||
* cluster size will be a multiple of the ballot bit size. Therefore, each
|
||||
* ballot component will be either all ones or all zeros. Build a vec for
|
||||
* which each component holds the value of `cluster` for which the mask
|
||||
* should be all ones.
|
||||
*/
|
||||
nir_const_value cluster_sel_const[4];
|
||||
assert(ARRAY_SIZE(cluster_sel_const) >= options->ballot_components);
|
||||
|
||||
for (unsigned i = 0; i < options->ballot_components; i++) {
|
||||
unsigned cluster_val =
|
||||
ROUND_DOWN_TO(i * options->ballot_bit_size, cluster_size);
|
||||
cluster_sel_const[i] =
|
||||
nir_const_value_for_uint(cluster_val, options->ballot_bit_size);
|
||||
}
|
||||
|
||||
nir_def *cluster_sel =
|
||||
nir_build_imm(b, options->ballot_components, options->ballot_bit_size,
|
||||
cluster_sel_const);
|
||||
nir_def *ones = nir_imm_intN_t(b, -1, options->ballot_bit_size);
|
||||
nir_def *zeros = nir_imm_intN_t(b, 0, options->ballot_bit_size);
|
||||
return nir_bcsel(b, nir_ieq(b, cluster, cluster_sel), ones, zeros);
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
|
||||
unsigned subgroup_size)
|
||||
const nir_lower_subgroups_options *options)
|
||||
{
|
||||
const nir_op red_op = nir_intrinsic_reduction_op(intrin);
|
||||
unsigned subgroup_size = options->subgroup_size;
|
||||
|
||||
/* Grab the cluster size */
|
||||
unsigned cluster_size = subgroup_size;
|
||||
|
|
@ -828,10 +856,11 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
|
|||
}
|
||||
|
||||
/* Check if all invocations are active. If so, we use the fast path. */
|
||||
nir_def *mask = nir_ballot(b, 1, subgroup_size, nir_imm_true(b));
|
||||
nir_def *mask = nir_ballot(b, options->ballot_components,
|
||||
options->ballot_bit_size, nir_imm_true(b));
|
||||
|
||||
nir_def *full, *partial;
|
||||
nir_push_if(b, nir_ieq_imm(b, mask, -1));
|
||||
nir_push_if(b, nir_ball_iequal(b, mask, build_subgroup_mask(b, options)));
|
||||
{
|
||||
full = build_scan_full(b, intrin->intrinsic, red_op,
|
||||
intrin->src[0].ssa, cluster_size);
|
||||
|
|
@ -840,13 +869,13 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
|
|||
{
|
||||
/* Mask according to the cluster size */
|
||||
if (cluster_size < subgroup_size) {
|
||||
nir_def *cluster_mask = build_cluster_mask(b, cluster_size);
|
||||
nir_def *cluster_mask = build_cluster_mask(b, cluster_size, options);
|
||||
mask = nir_iand(b, mask, cluster_mask);
|
||||
}
|
||||
|
||||
partial = build_scan_reduce(b, intrin->intrinsic, red_op,
|
||||
intrin->src[0].ssa, mask, cluster_size,
|
||||
subgroup_size);
|
||||
options);
|
||||
}
|
||||
nir_pop_if(b, NULL);
|
||||
return nir_if_phi(b, full, partial);
|
||||
|
|
@ -1261,7 +1290,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
|
|||
(options->lower_boolean_reduce || options->lower_reduce))
|
||||
return lower_boolean_reduce(b, intrin, options);
|
||||
if (options->lower_reduce)
|
||||
return lower_scan_reduce(b, intrin, options->subgroup_size);
|
||||
return lower_scan_reduce(b, intrin, options);
|
||||
return ret;
|
||||
}
|
||||
case nir_intrinsic_inclusive_scan:
|
||||
|
|
@ -1272,7 +1301,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
|
|||
(options->lower_boolean_reduce || options->lower_reduce))
|
||||
return lower_boolean_reduce(b, intrin, options);
|
||||
if (options->lower_reduce)
|
||||
return lower_scan_reduce(b, intrin, options->subgroup_size);
|
||||
return lower_scan_reduce(b, intrin, options);
|
||||
break;
|
||||
|
||||
case nir_intrinsic_rotate:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue