From f07138f2444ce865327fa18501fe8330e520bf12 Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Wed, 11 Mar 2026 15:24:15 -0700 Subject: [PATCH] spirv: Lower ShuffleUpINTEL and ShuffleDownINTEL to intrinsics Reviewed-by: Lionel Landwerlin Part-of: --- src/compiler/spirv/vtn_subgroup.c | 56 +++++++++++++------------------ 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c index 32ad555f4cb..a1af8809335 100644 --- a/src/compiler/spirv/vtn_subgroup.c +++ b/src/compiler/spirv/vtn_subgroup.c @@ -9,6 +9,7 @@ static struct vtn_ssa_value * vtn_build_subgroup_instr(struct vtn_builder *b, nir_intrinsic_op nir_op, struct vtn_ssa_value *src0, + struct vtn_ssa_value *src1, nir_def *index, unsigned const_idx0, unsigned const_idx1) @@ -24,6 +25,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b, vtn_assert(dst->type == src0->type); vtn_assert(glsl_type_is_vector_or_scalar(dst->type)); + if (src1) + vtn_assert(src1->type == src0->type); nir_intrinsic_instr *intrin = nir_intrinsic_instr_create(b->nb.shader, nir_op); @@ -31,8 +34,10 @@ vtn_build_subgroup_instr(struct vtn_builder *b, intrin->num_components = intrin->def.num_components; intrin->src[0] = nir_src_for_ssa(src0->def); + if (src1) + intrin->src[1] = nir_src_for_ssa(src1->def); if (index) - intrin->src[1] = nir_src_for_ssa(index); + intrin->src[src1 ? 2 : 1] = nir_src_for_ssa(index); intrin->const_index[0] = const_idx0; intrin->const_index[1] = const_idx1; @@ -140,7 +145,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, vtn_push_ssa_value(b, w[2], vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation, vtn_ssa_value(b, w[3 + has_scope]), - NULL, 0, 0)); + NULL, NULL, 0, 0)); break; } @@ -151,6 +156,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, vtn_push_ssa_value(b, w[2], vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation, vtn_ssa_value(b, w[3 + has_scope]), + NULL, vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0)); break; } @@ -250,6 +256,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, } vtn_push_ssa_value(b, w[2], vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), + NULL, vtn_get_nir_ssa(b, w[5]), 0, 0)); break; } @@ -260,40 +267,21 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor; vtn_push_ssa_value(b, w[2], vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]), + NULL, vtn_get_nir_ssa(b, w[4]), 0, 0)); break; } case SpvOpSubgroupShuffleUpINTEL: case SpvOpSubgroupShuffleDownINTEL: { - /* TODO: Move this lower on the compiler stack, where we can move the - * current/other data to adjacent registers to avoid doing a shuffle - * twice. - */ - - nir_builder *nb = &b->nb; - nir_def *size = nir_load_subgroup_size(nb); - nir_def *delta = vtn_get_nir_ssa(b, w[5]); - - /* Rewrite UP in terms of DOWN. - * - * UP(a, b, delta) == DOWN(a, b, size - delta) - */ - if (opcode == SpvOpSubgroupShuffleUpINTEL) - delta = nir_isub(nb, size, delta); - - nir_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta); - struct vtn_ssa_value *current = - vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]), - index, 0, 0); - - struct vtn_ssa_value *next = - vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]), - nir_isub(nb, index, size), 0, 0); - - nir_def *cond = nir_ilt(nb, index, size); - vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def)); - + nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleUpINTEL ? + nir_intrinsic_shuffle_up_intel : nir_intrinsic_shuffle_down_intel; + vtn_push_ssa_value(b, w[2], + vtn_build_subgroup_instr(b, op, + vtn_ssa_value(b, w[3]), + vtn_ssa_value(b, w[4]), + vtn_get_nir_ssa(b, w[5]), + 0, 0)); break; } @@ -306,7 +294,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, struct vtn_ssa_value *delta = vtn_ssa_value(b, w[5]); vtn_push_nir_ssa(b, w[2], vtn_build_subgroup_instr(b, nir_intrinsic_rotate, - value, delta->def, cluster_size, 0)->def); + value, NULL, delta->def, cluster_size, 0)->def); break; } @@ -323,6 +311,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, vtn_push_ssa_value(b, w[2], vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast, vtn_ssa_value(b, w[4]), + NULL, vtn_get_nir_ssa(b, w[5]), 0, 0)); break; @@ -346,7 +335,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap"); } vtn_push_ssa_value(b, w[2], - vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0)); + vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, NULL, + 0, 0)); break; } @@ -495,7 +485,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, } vtn_push_ssa_value(b, w[2], - vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL, + vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL, NULL, reduction_op, cluster_size)); break; }