nir: add lowering for boolean shuffle

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27116>
This commit is contained in:
Georg Lehmann 2024-01-17 14:12:43 +01:00 committed by Marge Bot
parent 37a15ba53a
commit d641750573
2 changed files with 94 additions and 3 deletions

View file

@ -5635,6 +5635,7 @@ typedef struct nir_lower_subgroups_options {
bool lower_ballot_bit_count_to_mbcnt_amd : 1;
bool lower_inverse_ballot : 1;
bool lower_boolean_reduce : 1;
bool lower_boolean_shuffle : 1;
} nir_lower_subgroups_options;
bool nir_lower_subgroups(nir_shader *shader,

View file

@ -355,6 +355,84 @@ lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin)
return nir_load_var(b, result);
}
static nir_def *
lower_boolean_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
const nir_lower_subgroups_options *options)
{
assert(options->ballot_components == 1 && options->subgroup_size);
nir_def *ballot = nir_ballot_relaxed(b, 1, options->ballot_bit_size, intrin->src[0].ssa);
nir_def *index = NULL;
/* If the shuffle amount isn't constant, it might be divergent but
* inverse_ballot requires a uniform source, so take a different path.
* rotate allows us to assume the delta is uniform unlike shuffle_up/down.
*/
switch (intrin->intrinsic) {
case nir_intrinsic_shuffle_up:
if (nir_src_is_const(intrin->src[1]))
ballot = nir_ishl(b, ballot, intrin->src[1].ssa);
else
index = nir_isub(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
break;
case nir_intrinsic_shuffle_down:
if (nir_src_is_const(intrin->src[1]))
ballot = nir_ushr(b, ballot, intrin->src[1].ssa);
else
index = nir_iadd(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
break;
case nir_intrinsic_shuffle_xor:
index = nir_ixor(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
break;
case nir_intrinsic_rotate: {
nir_def *delta = nir_as_uniform(b, intrin->src[1].ssa);
uint32_t cluster_size = nir_intrinsic_cluster_size(intrin);
cluster_size = cluster_size ? cluster_size : options->subgroup_size;
cluster_size = MIN2(cluster_size, options->subgroup_size);
if (cluster_size == 1) {
return intrin->src[0].ssa;
} else if (cluster_size == 2) {
delta = nir_iand_imm(b, delta, cluster_size - 1);
nir_def *lo = nir_iand_imm(b, nir_ushr_imm(b, ballot, 1), 0x5555555555555555ull);
nir_def *hi = nir_iand_imm(b, nir_ishl_imm(b, ballot, 1), 0xaaaaaaaaaaaaaaaaull);
ballot = nir_bcsel(b, nir_ine_imm(b, delta, 0), nir_ior(b, hi, lo), ballot);
} else if (cluster_size == ballot->bit_size) {
ballot = nir_uror(b, ballot, delta);
} else if (cluster_size == 32) {
nir_def *unpacked = nir_unpack_64_2x32(b, ballot);
unpacked = nir_uror(b, unpacked, delta);
ballot = nir_pack_64_2x32(b, unpacked);
} else {
delta = nir_iand_imm(b, delta, cluster_size - 1);
nir_def *delta_rev = nir_isub_imm(b, cluster_size, delta);
nir_def *mask = nir_mask(b, delta_rev, ballot->bit_size);
for (uint32_t i = cluster_size; i < ballot->bit_size; i *= 2) {
mask = nir_ior(b, nir_ishl_imm(b, mask, i), mask);
}
nir_def *lo = nir_iand(b, nir_ushr(b, ballot, delta), mask);
nir_def *hi = nir_iand(b, nir_ishl(b, ballot, delta_rev), nir_inot(b, mask));
ballot = nir_ior(b, lo, hi);
}
break;
}
case nir_intrinsic_shuffle:
index = intrin->src[1].ssa;
break;
case nir_intrinsic_read_invocation:
index = nir_as_uniform(b, intrin->src[1].ssa);
break;
default:
unreachable("not a boolean shuffle");
}
if (index) {
nir_def *mask = nir_ishl(b, nir_imm_intN_t(b, 1, ballot->bit_size), index);
return nir_ine_imm(b, nir_iand(b, ballot, mask), 0);
} else {
return nir_inverse_ballot(b, 1, ballot);
}
}
static nir_def *
vec_bit_count(nir_builder *b, nir_def *value)
{
@ -755,6 +833,9 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
if (options->lower_to_scalar && intrin->num_components > 1)
return lower_subgroup_op_to_scalar(b, intrin);
if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
return lower_boolean_shuffle(b, intrin, options);
if (options->lower_read_invocation_to_cond)
return lower_read_invocation_to_cond(b, intrin);
@ -918,20 +999,26 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
}
case nir_intrinsic_shuffle:
if (options->lower_shuffle)
if (options->lower_shuffle &&
(!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
return lower_shuffle(b, intrin);
else if (options->lower_to_scalar && intrin->num_components > 1)
return lower_subgroup_op_to_scalar(b, intrin);
else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
return lower_boolean_shuffle(b, intrin, options);
else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
return lower_subgroup_op_to_32bit(b, intrin);
break;
case nir_intrinsic_shuffle_xor:
case nir_intrinsic_shuffle_up:
case nir_intrinsic_shuffle_down:
if (options->lower_relative_shuffle)
if (options->lower_relative_shuffle &&
(!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
return lower_to_shuffle(b, intrin, options);
else if (options->lower_to_scalar && intrin->num_components > 1)
return lower_subgroup_op_to_scalar(b, intrin);
else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
return lower_boolean_shuffle(b, intrin, options);
else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
return lower_subgroup_op_to_32bit(b, intrin);
break;
@ -975,10 +1062,13 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
case nir_intrinsic_rotate:
if (nir_intrinsic_execution_scope(intrin) == SCOPE_SUBGROUP) {
if (options->lower_rotate_to_shuffle)
if (options->lower_rotate_to_shuffle &&
(!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1))
return lower_to_shuffle(b, intrin, options);
else if (options->lower_to_scalar && intrin->num_components > 1)
return lower_subgroup_op_to_scalar(b, intrin);
else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
return lower_boolean_shuffle(b, intrin, options);
else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
return lower_subgroup_op_to_32bit(b, intrin);
}