mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-05 05:18:08 +02:00
nir/lower_subgroups: move up some helper functions
build_subgroup_mask and build_ballot_imm_ishl will be needed by other functions higher up in the file. Signed-off-by: Job Noorman <jnoorman@igalia.com> Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io> Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31587>
This commit is contained in:
parent
085e7e419d
commit
e0cb4a94a3
1 changed files with 101 additions and 101 deletions
|
|
@ -62,6 +62,107 @@ lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
|
|||
return nir_pack_64_2x32_split(b, &intr_x->def, &intr_y->def);
|
||||
}
|
||||
|
||||
/* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
|
||||
* 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or
|
||||
* above the subgroup size for the masks, but gt_mask and ge_mask make them 1
|
||||
* so we have to "and" with this mask.
|
||||
*/
|
||||
static nir_def *
|
||||
build_subgroup_mask(nir_builder *b,
|
||||
const nir_lower_subgroups_options *options)
|
||||
{
|
||||
nir_def *subgroup_size = nir_load_subgroup_size(b);
|
||||
|
||||
/* First compute the result assuming one ballot component. */
|
||||
nir_def *result =
|
||||
nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
|
||||
nir_isub_imm(b, options->ballot_bit_size,
|
||||
subgroup_size));
|
||||
|
||||
/* Since the subgroup size and ballot bitsize are both powers of two, there
|
||||
* are two possible cases to consider:
|
||||
*
|
||||
* (1) The subgroup size is less than the ballot bitsize. We need to return
|
||||
* "result" in the first component and 0 in every other component.
|
||||
* (2) The subgroup size is a multiple of the ballot bitsize. We need to
|
||||
* return ~0 if the index in the vector is less than the subgroup size
|
||||
* divided by the ballot bitsize and 0 otherwise. For example,
|
||||
* with a target ballot type of 4 x uint32 and subgroup_size = 64 we'd need
|
||||
* to return { ~0, ~0, 0, 0 }.
|
||||
*
|
||||
* In case (2) it turns out that "result" will be ~0, because
|
||||
* "ballot_bit_size - subgroup_size" is also a multiple of
|
||||
* "ballot_bit_size" and since nir_ushr masks the shift value it will
|
||||
* be shifted by 0. This means that the first component can just be "result"
|
||||
* in all cases. The other components will also get the correct value in
|
||||
* case (1) if we just use the rule in case (2), so we'll get the correct
|
||||
* result if we just follow (2) and then replace the first component with
|
||||
* "result".
|
||||
*/
|
||||
nir_const_value min_idx[4];
|
||||
for (unsigned i = 0; i < options->ballot_components; i++)
|
||||
min_idx[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
|
||||
nir_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);
|
||||
|
||||
nir_def *result_extended =
|
||||
nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);
|
||||
|
||||
return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
|
||||
result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
|
||||
}
|
||||
|
||||
/* Return a ballot-mask-sized value which represents "val" sign-extended and
|
||||
* then shifted left by "shift". Only particular values for "val" are
|
||||
* supported, see below.
|
||||
*
|
||||
* This function assumes that `val << shift` will never span a ballot_bit_size
|
||||
* word and that the high bit of val can be extended across the entire result.
|
||||
* This is trivially satisfied for 0, 1, ~0, and ~1. However, it may also be
|
||||
* fine for other values if the shift is guaranteed to be sufficiently
|
||||
* aligned. One example is 0xf when the shift is known to be a multiple of 4.
|
||||
*/
|
||||
static nir_def *
|
||||
build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_def *shift,
|
||||
const nir_lower_subgroups_options *options)
|
||||
{
|
||||
/* First compute the result assuming one ballot component. */
|
||||
nir_def *result =
|
||||
nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);
|
||||
|
||||
if (options->ballot_components == 1)
|
||||
return result;
|
||||
|
||||
/* Fix up the result when there is > 1 component. The idea is that nir_ishl
|
||||
* masks out the high bits of the shift value already, so in case there's
|
||||
* more than one component the component which 1 would be shifted into
|
||||
* already has the right value and all we have to do is fixup the other
|
||||
* components. Components below it should always be 0, and components above
|
||||
* it must be either 0 or ~0 because of the assumptions on "val" stated above. For example, if
|
||||
* the target ballot size is 2 x uint32, and we're shifting 1 by 33, then
|
||||
* we'll feed 33 into ishl, which will mask it off to get 1, so we'll
|
||||
* compute a single-component result of 2, which is correct for the second
|
||||
* component, but the first component needs to be 0, which we get by
|
||||
* comparing the high bits of the shift with 0 and selecting the original
|
||||
* answer or 0 for the first component (and something similar with the
|
||||
* second component). This idea is generalized here for any component count.
|
||||
*/
|
||||
nir_const_value min_shift[4];
|
||||
for (unsigned i = 0; i < options->ballot_components; i++)
|
||||
min_shift[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
|
||||
nir_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);
|
||||
|
||||
nir_const_value max_shift[4];
|
||||
for (unsigned i = 0; i < options->ballot_components; i++)
|
||||
max_shift[i] = nir_const_value_for_int((i + 1) * options->ballot_bit_size, 32);
|
||||
nir_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);
|
||||
|
||||
return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
|
||||
nir_bcsel(b, nir_ult(b, shift, min_shift_val),
|
||||
nir_imm_intN_t(b, val >> 63, result->bit_size),
|
||||
result),
|
||||
nir_imm_intN_t(b, 0, result->bit_size));
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
ballot_type_to_uint(nir_builder *b, nir_def *value,
|
||||
const nir_lower_subgroups_options *options)
|
||||
|
|
@ -752,58 +853,6 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options)
|
|||
return instr->type == nir_instr_type_intrinsic;
|
||||
}
|
||||
|
||||
/* Return a ballot-mask-sized value which represents "val" sign-extended and
|
||||
* then shifted left by "shift". Only particular values for "val" are
|
||||
* supported, see below.
|
||||
*
|
||||
* This function assumes that `val << shift` will never span a ballot_bit_size
|
||||
* word and that the high bit of val can be extended across the entire result.
|
||||
* This is trivially satisfied for 0, 1, ~0, and ~1. However, it may also be
|
||||
* fine for other values if the shift is guaranteed to be sufficiently
|
||||
* aligned. One example is 0xf when the shift is known to be a multiple of 4.
|
||||
*/
|
||||
static nir_def *
|
||||
build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_def *shift,
|
||||
const nir_lower_subgroups_options *options)
|
||||
{
|
||||
/* First compute the result assuming one ballot component. */
|
||||
nir_def *result =
|
||||
nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);
|
||||
|
||||
if (options->ballot_components == 1)
|
||||
return result;
|
||||
|
||||
/* Fix up the result when there is > 1 component. The idea is that nir_ishl
|
||||
* masks out the high bits of the shift value already, so in case there's
|
||||
* more than one component the component which 1 would be shifted into
|
||||
* already has the right value and all we have to do is fixup the other
|
||||
* components. Components below it should always be 0, and components above
|
||||
* it must be either 0 or ~0 because of the assumptions on "val" stated above. For example, if
|
||||
* the target ballot size is 2 x uint32, and we're shifting 1 by 33, then
|
||||
* we'll feed 33 into ishl, which will mask it off to get 1, so we'll
|
||||
* compute a single-component result of 2, which is correct for the second
|
||||
* component, but the first component needs to be 0, which we get by
|
||||
* comparing the high bits of the shift with 0 and selecting the original
|
||||
* answer or 0 for the first component (and something similar with the
|
||||
* second component). This idea is generalized here for any component count.
|
||||
*/
|
||||
nir_const_value min_shift[4];
|
||||
for (unsigned i = 0; i < options->ballot_components; i++)
|
||||
min_shift[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
|
||||
nir_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);
|
||||
|
||||
nir_const_value max_shift[4];
|
||||
for (unsigned i = 0; i < options->ballot_components; i++)
|
||||
max_shift[i] = nir_const_value_for_int((i + 1) * options->ballot_bit_size, 32);
|
||||
nir_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);
|
||||
|
||||
return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
|
||||
nir_bcsel(b, nir_ult(b, shift, min_shift_val),
|
||||
nir_imm_intN_t(b, val >> 63, result->bit_size),
|
||||
result),
|
||||
nir_imm_intN_t(b, 0, result->bit_size));
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
build_subgroup_eq_mask(nir_builder *b,
|
||||
const nir_lower_subgroups_options *options)
|
||||
|
|
@ -841,55 +890,6 @@ build_subgroup_quad_mask(nir_builder *b,
|
|||
return build_ballot_imm_ishl(b, 0xf, quad_first_idx, options);
|
||||
}
|
||||
|
||||
/* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
|
||||
* 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or
|
||||
* above the subgroup size for the masks, but gt_mask and ge_mask make them 1
|
||||
* so we have to "and" with this mask.
|
||||
*/
|
||||
static nir_def *
|
||||
build_subgroup_mask(nir_builder *b,
|
||||
const nir_lower_subgroups_options *options)
|
||||
{
|
||||
nir_def *subgroup_size = nir_load_subgroup_size(b);
|
||||
|
||||
/* First compute the result assuming one ballot component. */
|
||||
nir_def *result =
|
||||
nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
|
||||
nir_isub_imm(b, options->ballot_bit_size,
|
||||
subgroup_size));
|
||||
|
||||
/* Since the subgroup size and ballot bitsize are both powers of two, there
|
||||
* are two possible cases to consider:
|
||||
*
|
||||
* (1) The subgroup size is less than the ballot bitsize. We need to return
|
||||
* "result" in the first component and 0 in every other component.
|
||||
* (2) The subgroup size is a multiple of the ballot bitsize. We need to
|
||||
* return ~0 if the index in the vector is less than the subgroup size
|
||||
* divided by the ballot bitsize and 0 otherwise. For example,
|
||||
* with a target ballot type of 4 x uint32 and subgroup_size = 64 we'd need
|
||||
* to return { ~0, ~0, 0, 0 }.
|
||||
*
|
||||
* In case (2) it turns out that "result" will be ~0, because
|
||||
* "ballot_bit_size - subgroup_size" is also a multiple of
|
||||
* "ballot_bit_size" and since nir_ushr masks the shift value it will
|
||||
* be shifted by 0. This means that the first component can just be "result"
|
||||
* in all cases. The other components will also get the correct value in
|
||||
* case (1) if we just use the rule in case (2), so we'll get the correct
|
||||
* result if we just follow (2) and then replace the first component with
|
||||
* "result".
|
||||
*/
|
||||
nir_const_value min_idx[4];
|
||||
for (unsigned i = 0; i < options->ballot_components; i++)
|
||||
min_idx[i] = nir_const_value_for_int(i * options->ballot_bit_size, 32);
|
||||
nir_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);
|
||||
|
||||
nir_def *result_extended =
|
||||
nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);
|
||||
|
||||
return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
|
||||
result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
build_quad_vote_any(nir_builder *b, nir_def *src,
|
||||
const nir_lower_subgroups_options *options)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue