From d641750573a163674458004f74d0057bec86c62e Mon Sep 17 00:00:00 2001
From: Georg Lehmann
Date: Wed, 17 Jan 2024 14:12:43 +0100
Subject: [PATCH] nir: add lowering for boolean shuffle

Reviewed-by: Rhys Perry
Part-of:
---
 src/compiler/nir/nir.h                 |  1 +
 src/compiler/nir/nir_lower_subgroups.c | 96 +++++++++++++++++++++++++-
 2 files changed, 94 insertions(+), 3 deletions(-)

diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 4c29e35c12f..96386c53e5a 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -5635,6 +5635,7 @@ typedef struct nir_lower_subgroups_options {
    bool lower_ballot_bit_count_to_mbcnt_amd : 1;
    bool lower_inverse_ballot : 1;
    bool lower_boolean_reduce : 1;
+   bool lower_boolean_shuffle : 1;
 } nir_lower_subgroups_options;
 
 bool nir_lower_subgroups(nir_shader *shader,
diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index 69e44dd236b..d35cac3818f 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -355,6 +355,84 @@ lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin)
    return nir_load_var(b, result);
 }
 
+static nir_def *
+lower_boolean_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
+                      const nir_lower_subgroups_options *options)
+{
+   assert(options->ballot_components == 1 && options->subgroup_size);
+   nir_def *ballot = nir_ballot_relaxed(b, 1, options->ballot_bit_size, intrin->src[0].ssa);
+
+   nir_def *index = NULL; /* when set, bit 'index' of the ballot is tested instead of using inverse_ballot */
+
+   /* If the shuffle amount isn't constant, it might be divergent but
+    * inverse_ballot requires a uniform source, so take a different path.
+    * rotate allows us to assume the delta is uniform unlike shuffle_up/down.
+    */
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_shuffle_up:
+      if (nir_src_is_const(intrin->src[1]))
+         ballot = nir_ishl(b, ballot, intrin->src[1].ssa); /* constant delta: shift the whole ballot up */
+      else
+         index = nir_isub(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
+      break;
+   case nir_intrinsic_shuffle_down:
+      if (nir_src_is_const(intrin->src[1]))
+         ballot = nir_ushr(b, ballot, intrin->src[1].ssa); /* constant delta: shift the whole ballot down */
+      else
+         index = nir_iadd(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
+      break;
+   case nir_intrinsic_shuffle_xor:
+      index = nir_ixor(b, nir_load_subgroup_invocation(b), intrin->src[1].ssa);
+      break;
+   case nir_intrinsic_rotate: {
+      nir_def *delta = nir_as_uniform(b, intrin->src[1].ssa);
+      uint32_t cluster_size = nir_intrinsic_cluster_size(intrin);
+      cluster_size = cluster_size ? cluster_size : options->subgroup_size; /* 0 is treated as the full subgroup */
+      cluster_size = MIN2(cluster_size, options->subgroup_size);
+      if (cluster_size == 1) {
+         return intrin->src[0].ssa; /* source is returned unchanged */
+      } else if (cluster_size == 2) {
+         delta = nir_iand_imm(b, delta, cluster_size - 1); /* delta is now 0 or 1 */
+         nir_def *lo = nir_iand_imm(b, nir_ushr_imm(b, ballot, 1), 0x5555555555555555ull);
+         nir_def *hi = nir_iand_imm(b, nir_ishl_imm(b, ballot, 1), 0xaaaaaaaaaaaaaaaaull);
+         ballot = nir_bcsel(b, nir_ine_imm(b, delta, 0), nir_ior(b, hi, lo), ballot); /* swap adjacent bits iff delta != 0 */
+      } else if (cluster_size == ballot->bit_size) {
+         ballot = nir_uror(b, ballot, delta); /* one cluster spans the whole ballot */
+      } else if (cluster_size == 32) {
+         nir_def *unpacked = nir_unpack_64_2x32(b, ballot);
+         unpacked = nir_uror(b, unpacked, delta); /* presumably rotates each 32-bit half separately - TODO confirm */
+         ballot = nir_pack_64_2x32(b, unpacked);
+      } else {
+         delta = nir_iand_imm(b, delta, cluster_size - 1);
+         nir_def *delta_rev = nir_isub_imm(b, cluster_size, delta);
+         nir_def *mask = nir_mask(b, delta_rev, ballot->bit_size); /* presumably the low delta_rev bits - TODO confirm */
+         for (uint32_t i = cluster_size; i < ballot->bit_size; i *= 2) {
+            mask = nir_ior(b, nir_ishl_imm(b, mask, i), mask); /* replicate the mask at doubling offsets */
+         }
+         nir_def *lo = nir_iand(b, nir_ushr(b, ballot, delta), mask);
+         nir_def *hi = nir_iand(b, nir_ishl(b, ballot, delta_rev), nir_inot(b, mask));
+         ballot = nir_ior(b, lo, hi);
+      }
+      break;
+   }
+   case nir_intrinsic_shuffle:
+      index = intrin->src[1].ssa;
+      break;
+   case nir_intrinsic_read_invocation:
+      index = nir_as_uniform(b, intrin->src[1].ssa);
+      break;
+   default:
+      unreachable("not a boolean shuffle");
+   }
+
+   if (index) {
+      nir_def *mask = nir_ishl(b, nir_imm_intN_t(b, 1, ballot->bit_size), index);
+      return nir_ine_imm(b, nir_iand(b, ballot, mask), 0); /* result is whether bit 'index' of the ballot is set */
+   } else {
+      return nir_inverse_ballot(b, 1, ballot); /* ballot was permuted as a whole above */
+   }
+}
+
 static nir_def *
 vec_bit_count(nir_builder *b, nir_def *value)
 {
@@ -755,6 +833,9 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_subgroup_op_to_scalar(b, intrin);
 
+      if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
+         return lower_boolean_shuffle(b, intrin, options);
+
       if (options->lower_read_invocation_to_cond)
          return lower_read_invocation_to_cond(b, intrin);
 
@@ -918,20 +999,26 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
    }
 
    case nir_intrinsic_shuffle:
-      if (options->lower_shuffle)
+      if (options->lower_shuffle &&
+          (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1)) /* 1-bit sources are left for lower_boolean_shuffle */
          return lower_shuffle(b, intrin);
       else if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_subgroup_op_to_scalar(b, intrin);
+      else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
+         return lower_boolean_shuffle(b, intrin, options);
       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
          return lower_subgroup_op_to_32bit(b, intrin);
       break;
    case nir_intrinsic_shuffle_xor:
    case nir_intrinsic_shuffle_up:
    case nir_intrinsic_shuffle_down:
-      if (options->lower_relative_shuffle)
+      if (options->lower_relative_shuffle &&
+          (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1)) /* 1-bit sources are left for lower_boolean_shuffle */
          return lower_to_shuffle(b, intrin, options);
       else if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_subgroup_op_to_scalar(b, intrin);
+      else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
+         return lower_boolean_shuffle(b, intrin, options);
       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
          return lower_subgroup_op_to_32bit(b, intrin);
       break;
@@ -975,10 +1062,13 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 
    case nir_intrinsic_rotate:
       if (nir_intrinsic_execution_scope(intrin) == SCOPE_SUBGROUP) {
-         if (options->lower_rotate_to_shuffle)
+         if (options->lower_rotate_to_shuffle &&
+             (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1)) /* 1-bit sources are left for lower_boolean_shuffle */
             return lower_to_shuffle(b, intrin, options);
          else if (options->lower_to_scalar && intrin->num_components > 1)
             return lower_subgroup_op_to_scalar(b, intrin);
+         else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1)
+            return lower_boolean_shuffle(b, intrin, options);
          else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
             return lower_subgroup_op_to_32bit(b, intrin);
       }