nir: Add support for lowering shuffle to a waterfall loop

Qualcomm doesn't natively support shuffle, but it does natively support
relative shuffles where the delta is a constant. Therefore we'll expose
emulated support for both. Add support for this emulation of
subgroupShuffle() to NIR.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Reviewed-by: Danylo Piliaiev <dpiliaiev@igalia.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14412>
This commit is contained in:
Connor Abbott 2022-01-04 15:44:31 +01:00 committed by Marge Bot
parent 913bec10c4
commit 503a5bae59
2 changed files with 81 additions and 1 deletions

View file

@ -4682,6 +4682,7 @@ typedef struct nir_lower_subgroups_options {
bool lower_relative_shuffle:1;
bool lower_shuffle_to_32bit:1;
bool lower_shuffle_to_swizzle_amd:1;
bool lower_shuffle:1;
bool lower_quad:1;
bool lower_quad_broadcast_dynamic:1;
bool lower_quad_broadcast_dynamic_to_const:1;

View file

@ -305,6 +305,83 @@ lower_to_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
}
}
static const struct glsl_type *
glsl_type_for_ssa(nir_ssa_def *def)
{
const struct glsl_type *comp_type = def->bit_size == 1 ? glsl_bool_type() :
glsl_uintN_t_type(def->bit_size);
return glsl_replace_vector_type(comp_type, def->num_components);
}
/* Lower nir_intrinsic_shuffle to a waterfall loop + nir_read_invocation.
*/
static nir_ssa_def *
lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin)
{
assert(intrin->src[0].is_ssa);
assert(intrin->src[1].is_ssa);
nir_ssa_def *val = intrin->src[0].ssa;
nir_ssa_def *id = intrin->src[1].ssa;
/* The loop is something like:
*
* while (true) {
* first_id = readFirstInvocation(gl_SubgroupInvocationID);
* first_val = readFirstInvocation(val);
* first_result = readInvocation(val, readFirstInvocation(id));
* if (id == first_id)
* result = first_val;
* if (elect()) {
* if (id > gl_SubgroupInvocationID) {
* result = first_result;
* }
* break;
* }
* }
*
* The idea is to guarantee, on each iteration of the loop, that anything
* reading from first_id gets the correct value, so that we can then kill
* it off by breaking out of the loop. Before doing that we also have to
* ensure that first_id invocation gets the correct value. It only won't be
* assigned the correct value already if the invocation it's reading from
* isn't already killed off, that is, if it's later than its own ID.
* Invocations where id <= gl_SubgroupInvocationID will be assigned their
* result in the first if, and invocations where id >
* gl_SubgroupInvocationID will be assigned their result in the second if.
*
* We do this more complicated loop rather than looping over all id's
* explicitly because at this point we don't know the "actual" subgroup
* size and at the moment there's no way to get at it, which means we may
* loop over always-inactive invocations.
*/
nir_ssa_def *subgroup_id = nir_load_subgroup_invocation(b);
nir_variable *result =
nir_local_variable_create(b->impl, glsl_type_for_ssa(val), "result");
nir_loop *loop = nir_push_loop(b); {
nir_ssa_def *first_id = nir_read_first_invocation(b, subgroup_id);
nir_ssa_def *first_val = nir_read_first_invocation(b, val);
nir_ssa_def *first_result =
nir_read_invocation(b, val, nir_read_first_invocation(b, id));
nir_if *nif = nir_push_if(b, nir_ieq(b, id, first_id)); {
nir_store_var(b, result, first_val, BITFIELD_MASK(val->num_components));
} nir_pop_if(b, nif);
nir_if *nif2 = nir_push_if(b, nir_elect(b, 1)); {
nir_if *nif3 = nir_push_if(b, nir_ult(b, subgroup_id, id)); {
nir_store_var(b, result, first_result, BITFIELD_MASK(val->num_components));
} nir_pop_if(b, nif3);
nir_jump(b, nir_jump_break);
} nir_pop_if(b, nif2);
} nir_pop_loop(b, loop);
return nir_load_var(b, result);
}
static bool
lower_subgroups_filter(const nir_instr *instr, const void *_options)
{
@ -702,7 +779,9 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
}
case nir_intrinsic_shuffle:
if (options->lower_to_scalar && intrin->num_components > 1)
if (options->lower_shuffle)
return lower_shuffle(b, intrin);
else if (options->lower_to_scalar && intrin->num_components > 1)
return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
return lower_subgroup_op_to_32bit(b, intrin);