nir/lower_subgroups: move up some helper functions

build_subgroup_mask and build_ballot_imm_ishl will be needed by other
functions higher-up the file.

Signed-off-by: Job Noorman <jnoorman@igalia.com>
Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31587>
This commit is contained in:
Job Noorman 2024-10-17 21:44:39 +02:00 committed by Marge Bot
parent 085e7e419d
commit e0cb4a94a3

View file

@ -62,6 +62,107 @@ lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
return nir_pack_64_2x32_split(b, &intr_x->def, &intr_y->def);
}
/* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
* 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or
* above the subgroup size for the masks, but gt_mask and ge_mask make them 1
* so we have to "and" with this mask.
*/
static nir_def *
build_subgroup_mask(nir_builder *b,
const nir_lower_subgroups_options *options)
{
nir_def *subgroup_size = nir_load_subgroup_size(b);
/* First compute the result assuming one ballot component. */
nir_def *result =
nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
nir_isub_imm(b, options->ballot_bit_size,
subgroup_size));
/* Since the subgroup size and ballot bitsize are both powers of two, there
* are two possible cases to consider:
*
* (1) The subgroup size is less than the ballot bitsize. We need to return
* "result" in the first component and 0 in every other component.
* (2) The subgroup size is a multiple of the ballot bitsize. We need to
* return ~0 if the subgroup size divided by the ballot bitsize is less
* than or equal to the index in the vector and 0 otherwise. For example,
* with a target ballot type of 4 x uint32 and subgroup_size = 64 we'd need
* to return { ~0, ~0, 0, 0 }.
*
* In case (2) it turns out that "result" will be ~0, because
* "ballot_bit_size - subgroup_size" is also a multiple of
* "ballot_bit_size" and since nir_ushr masks the shift value it will
* shifted by 0. This means that the first component can just be "result"
* in all cases. The other components will also get the correct value in
* case (1) if we just use the rule in case (2), so we'll get the correct
* result if we just follow (2) and then replace the first component with
* "result".
*/
nir_const_value min_idx[4];
for (unsigned i = 0; i < options->ballot_components; i++)
min_idx[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
nir_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);
nir_def *result_extended =
nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);
return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
}
/* Return a ballot-mask-sized value which represents "val" sign-extended and
* then shifted left by "shift". Only particular values for "val" are
* supported, see below.
*
* This function assumes that `val << shift` will never span a ballot_bit_size
* word and that the high bit of val can be extended across the entire result.
* This is trivially satisfied for 0, 1, ~0, and ~1. However, it may also be
* fine for other values if the shift is guaranteed to be sufficiently
* aligned. One example is 0xf when the shift is known to be a multiple of 4.
*/
static nir_def *
build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_def *shift,
const nir_lower_subgroups_options *options)
{
/* First compute the result assuming one ballot component. */
nir_def *result =
nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);
if (options->ballot_components == 1)
return result;
/* Fix up the result when there is > 1 component. The idea is that nir_ishl
* masks out the high bits of the shift value already, so in case there's
* more than one component the component which 1 would be shifted into
* already has the right value and all we have to do is fixup the other
* components. Components below it should always be 0, and components above
* it must be either 0 or ~0 because of the assert above. For example, if
* the target ballot size is 2 x uint32, and we're shifting 1 by 33, then
* we'll feed 33 into ishl, which will mask it off to get 1, so we'll
* compute a single-component result of 2, which is correct for the second
* component, but the first component needs to be 0, which we get by
* comparing the high bits of the shift with 0 and selecting the original
* answer or 0 for the first component (and something similar with the
* second component). This idea is generalized here for any component count
*/
nir_const_value min_shift[4];
for (unsigned i = 0; i < options->ballot_components; i++)
min_shift[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
nir_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);
nir_const_value max_shift[4];
for (unsigned i = 0; i < options->ballot_components; i++)
max_shift[i] = nir_const_value_for_int((i + 1) * options->ballot_bit_size, 32);
nir_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);
return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
nir_bcsel(b, nir_ult(b, shift, min_shift_val),
nir_imm_intN_t(b, val >> 63, result->bit_size),
result),
nir_imm_intN_t(b, 0, result->bit_size));
}
static nir_def *
ballot_type_to_uint(nir_builder *b, nir_def *value,
const nir_lower_subgroups_options *options)
@ -752,58 +853,6 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options)
return instr->type == nir_instr_type_intrinsic;
}
/* Return a ballot-mask-sized value which represents "val" sign-extended and
* then shifted left by "shift". Only particular values for "val" are
* supported, see below.
*
* This function assumes that `val << shift` will never span a ballot_bit_size
* word and that the high bit of val can be extended across the entire result.
* This is trivially satisfied for 0, 1, ~0, and ~1. However, it may also be
* fine for other values if the shift is guaranteed to be sufficiently
* aligned. One example is 0xf when the shift is known to be a multiple of 4.
*/
static nir_def *
build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_def *shift,
const nir_lower_subgroups_options *options)
{
/* First compute the result assuming one ballot component. */
nir_def *result =
nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);
if (options->ballot_components == 1)
return result;
/* Fix up the result when there is > 1 component. The idea is that nir_ishl
* masks out the high bits of the shift value already, so in case there's
* more than one component the component which 1 would be shifted into
* already has the right value and all we have to do is fixup the other
* components. Components below it should always be 0, and components above
* it must be either 0 or ~0 because of the assert above. For example, if
* the target ballot size is 2 x uint32, and we're shifting 1 by 33, then
* we'll feed 33 into ishl, which will mask it off to get 1, so we'll
* compute a single-component result of 2, which is correct for the second
* component, but the first component needs to be 0, which we get by
* comparing the high bits of the shift with 0 and selecting the original
* answer or 0 for the first component (and something similar with the
* second component). This idea is generalized here for any component count
*/
nir_const_value min_shift[4];
for (unsigned i = 0; i < options->ballot_components; i++)
min_shift[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
nir_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);
nir_const_value max_shift[4];
for (unsigned i = 0; i < options->ballot_components; i++)
max_shift[i] = nir_const_value_for_int((i + 1) * options->ballot_bit_size, 32);
nir_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);
return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
nir_bcsel(b, nir_ult(b, shift, min_shift_val),
nir_imm_intN_t(b, val >> 63, result->bit_size),
result),
nir_imm_intN_t(b, 0, result->bit_size));
}
static nir_def *
build_subgroup_eq_mask(nir_builder *b,
const nir_lower_subgroups_options *options)
@ -841,55 +890,6 @@ build_subgroup_quad_mask(nir_builder *b,
return build_ballot_imm_ishl(b, 0xf, quad_first_idx, options);
}
/* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
* 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or
* above the subgroup size for the masks, but gt_mask and ge_mask make them 1
* so we have to "and" with this mask.
*/
static nir_def *
build_subgroup_mask(nir_builder *b,
const nir_lower_subgroups_options *options)
{
nir_def *subgroup_size = nir_load_subgroup_size(b);
/* First compute the result assuming one ballot component. */
nir_def *result =
nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
nir_isub_imm(b, options->ballot_bit_size,
subgroup_size));
/* Since the subgroup size and ballot bitsize are both powers of two, there
* are two possible cases to consider:
*
* (1) The subgroup size is less than the ballot bitsize. We need to return
* "result" in the first component and 0 in every other component.
* (2) The subgroup size is a multiple of the ballot bitsize. We need to
* return ~0 if the subgroup size divided by the ballot bitsize is less
* than or equal to the index in the vector and 0 otherwise. For example,
* with a target ballot type of 4 x uint32 and subgroup_size = 64 we'd need
* to return { ~0, ~0, 0, 0 }.
*
* In case (2) it turns out that "result" will be ~0, because
* "ballot_bit_size - subgroup_size" is also a multiple of
* "ballot_bit_size" and since nir_ushr masks the shift value it will
* shifted by 0. This means that the first component can just be "result"
* in all cases. The other components will also get the correct value in
* case (1) if we just use the rule in case (2), so we'll get the correct
* result if we just follow (2) and then replace the first component with
* "result".
*/
nir_const_value min_idx[4];
for (unsigned i = 0; i < options->ballot_components; i++)
min_idx[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
nir_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);
nir_def *result_extended =
nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);
return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
}
static nir_def *
build_quad_vote_any(nir_builder *b, nir_def *src,
const nir_lower_subgroups_options *options)