From 07260dc210a2be5e09d987fdc7e7318b5ed2ad0a Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Sat, 21 Feb 2026 11:39:48 +0100 Subject: [PATCH] nir/lower_subgroups: lower shuffles and bitwise reduce to 32bit before scalarizing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pack/unpack should be a lot faster than duplicating the subgroup op. No fossil-db changes, but multiple people complained about this to me. Reviewed-by: Daniel Schürmann Part-of: --- src/compiler/nir/nir_lower_subgroups.c | 51 +++++++++++++++++++------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index b3316045a94..bede202b9e6 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -219,19 +219,25 @@ uint_to_ballot_type(nir_builder *b, nir_def *value, } static nir_def * -lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) +lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin, bool bitcast_to_32bit) { /* This is safe to call on scalar things but it would be silly */ assert(intrin->def.num_components > 1); nir_def *value = intrin->src[0].ssa; + if ((value->bit_size == 8 || value->bit_size == 16) && bitcast_to_32bit) { + unsigned num32 = DIV_ROUND_UP(value->bit_size * value->num_components, 32); + value = nir_pad_vector(b, value, num32 * (32 / value->bit_size)); + value = nir_bitcast_vector(b, value, 32); + } + nir_def *reads[NIR_MAX_VEC_COMPONENTS]; - for (unsigned i = 0; i < intrin->num_components; i++) { + for (unsigned i = 0; i < value->num_components; i++) { nir_intrinsic_instr *chan_intrin = nir_intrinsic_instr_create(b->shader, intrin->intrinsic); nir_def_init(&chan_intrin->instr, &chan_intrin->def, 1, - intrin->def.bit_size); + value->bit_size); chan_intrin->num_components = 1; /* value */ @@ -249,7 +255,13 @@ lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) reads[i] = &chan_intrin->def; } - return nir_vec(b, reads, intrin->num_components); + value = nir_vec(b, reads, value->num_components); + + if (value->bit_size != intrin->def.bit_size) { + value = nir_bitcast_vector(b, value, intrin->def.bit_size); + value = nir_trim_vector(b, value, intrin->def.num_components); + } + return value; } static nir_def * @@ -1054,6 +1066,19 @@ lower_read_invocation_to_cond(nir_builder *b, nir_intrinsic_instr *intrin) nir_load_subgroup_invocation(b))); } +static bool +is_bitwise(nir_op op) +{ + switch (op) { + case nir_op_iand: + case nir_op_ior: + case nir_op_ixor: + return true; + default: + return false; + } +} + static nir_def * lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) { @@ -1104,7 +1129,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) case nir_intrinsic_read_invocation: if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, true); if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1) return lower_boolean_shuffle(b, intrin, options); @@ -1116,7 +1141,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) case nir_intrinsic_read_first_invocation: if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, true); if (options->lower_read_first_invocation) return lower_read_first_invocation(b, intrin); @@ -1278,7 +1303,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1)) return lower_shuffle(b, intrin); else if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, true); else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1) return lower_boolean_shuffle(b, intrin, options); else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) @@ -1291,7 +1316,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1)) return lower_to_shuffle(b, intrin, options); else if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, true); else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1) return lower_boolean_shuffle(b, intrin, options); else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) @@ -1308,7 +1333,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) !nir_src_is_const(intrin->src[1]))) return lower_dynamic_quad_broadcast(b, intrin, options); else if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, true); break; case nir_intrinsic_quad_vote_any: @@ -1334,7 +1359,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) if (nir_intrinsic_cluster_size(intrin) == 1) return intrin->src[0].ssa; if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, is_bitwise(nir_intrinsic_reduction_op(intrin))); if (intrin->def.bit_size == 1 && options->ballot_components == 1 && (options->lower_boolean_reduce || options->lower_reduce)) return lower_boolean_reduce(b, intrin, options); @@ -1345,7 +1370,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) case nir_intrinsic_inclusive_scan: case nir_intrinsic_exclusive_scan: if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, is_bitwise(nir_intrinsic_reduction_op(intrin))); if (intrin->def.bit_size == 1 && options->ballot_components == 1 && (options->lower_boolean_reduce || options->lower_reduce)) return lower_boolean_reduce(b, intrin, options); @@ -1362,7 +1387,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) (!options->lower_boolean_shuffle || intrin->src[0].ssa->bit_size != 1)) return lower_to_shuffle(b, intrin, options); else if (options->lower_to_scalar && intrin->num_components > 1) - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, true); else if (options->lower_boolean_shuffle && intrin->src[0].ssa->bit_size == 1) return lower_boolean_shuffle(b, intrin, options); else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) @@ -1370,7 +1395,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) break; case nir_intrinsic_masked_swizzle_amd: if (options->lower_to_scalar && intrin->num_components > 1) { - return lower_subgroup_op_to_scalar(b, intrin); + return lower_subgroup_op_to_scalar(b, intrin, true); } else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64) { return lower_subgroup_op_to_32bit(b, intrin); }