spirv: Lower ShuffleUpINTEL and ShuffleDownINTEL to intrinsics

Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40376>
This commit is contained in:
Caio Oliveira 2026-03-11 15:24:15 -07:00 committed by Marge Bot
parent dcba49d7ef
commit f07138f244

View file

@ -9,6 +9,7 @@ static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
nir_intrinsic_op nir_op,
struct vtn_ssa_value *src0,
struct vtn_ssa_value *src1,
nir_def *index,
unsigned const_idx0,
unsigned const_idx1)
@ -24,6 +25,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
vtn_assert(dst->type == src0->type);
vtn_assert(glsl_type_is_vector_or_scalar(dst->type));
if (src1)
vtn_assert(src1->type == src0->type);
nir_intrinsic_instr *intrin =
nir_intrinsic_instr_create(b->nb.shader, nir_op);
@ -31,8 +34,10 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
intrin->num_components = intrin->def.num_components;
intrin->src[0] = nir_src_for_ssa(src0->def);
if (src1)
intrin->src[1] = nir_src_for_ssa(src1->def);
if (index)
intrin->src[1] = nir_src_for_ssa(index);
intrin->src[src1 ? 2 : 1] = nir_src_for_ssa(index);
intrin->const_index[0] = const_idx0;
intrin->const_index[1] = const_idx1;
@ -140,7 +145,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
vtn_ssa_value(b, w[3 + has_scope]),
NULL, 0, 0));
NULL, NULL, 0, 0));
break;
}
@ -151,6 +156,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
vtn_ssa_value(b, w[3 + has_scope]),
NULL,
vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
break;
}
@ -250,6 +256,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
}
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
NULL,
vtn_get_nir_ssa(b, w[5]), 0, 0));
break;
}
@ -260,40 +267,21 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
NULL,
vtn_get_nir_ssa(b, w[4]), 0, 0));
break;
}
case SpvOpSubgroupShuffleUpINTEL:
case SpvOpSubgroupShuffleDownINTEL: {
/* TODO: Move this lower on the compiler stack, where we can move the
* current/other data to adjacent registers to avoid doing a shuffle
* twice.
*/
nir_builder *nb = &b->nb;
nir_def *size = nir_load_subgroup_size(nb);
nir_def *delta = vtn_get_nir_ssa(b, w[5]);
/* Rewrite UP in terms of DOWN.
*
* UP(a, b, delta) == DOWN(a, b, size - delta)
*/
if (opcode == SpvOpSubgroupShuffleUpINTEL)
delta = nir_isub(nb, size, delta);
nir_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
struct vtn_ssa_value *current =
vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
index, 0, 0);
struct vtn_ssa_value *next =
vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
nir_isub(nb, index, size), 0, 0);
nir_def *cond = nir_ilt(nb, index, size);
vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));
nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleUpINTEL ?
nir_intrinsic_shuffle_up_intel : nir_intrinsic_shuffle_down_intel;
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, op,
vtn_ssa_value(b, w[3]),
vtn_ssa_value(b, w[4]),
vtn_get_nir_ssa(b, w[5]),
0, 0));
break;
}
@ -306,7 +294,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
struct vtn_ssa_value *delta = vtn_ssa_value(b, w[5]);
vtn_push_nir_ssa(b, w[2],
vtn_build_subgroup_instr(b, nir_intrinsic_rotate,
value, delta->def, cluster_size, 0)->def);
value, NULL, delta->def, cluster_size, 0)->def);
break;
}
@ -323,6 +311,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
vtn_ssa_value(b, w[4]),
NULL,
vtn_get_nir_ssa(b, w[5]), 0, 0));
break;
@ -346,7 +335,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
}
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, NULL,
0, 0));
break;
}
@ -495,7 +485,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
}
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL, NULL,
reduction_op, cluster_size));
break;
}