zink: implement ops for KHR_shader_subgroup

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31228>
This commit is contained in:
Mike Blumenkrantz 2025-01-24 07:33:24 -05:00 committed by Marge Bot
parent b4f3136fea
commit cee77ba6ec
3 changed files with 236 additions and 1 deletions

View file

@ -109,7 +109,8 @@ struct ntv_context {
subgroup_invocation_var,
subgroup_le_mask_var,
subgroup_lt_mask_var,
subgroup_size_var;
subgroup_size_var,
num_subgroups_var;
SpvId discard_func;
SpvId float_array_type[2];
@ -3250,6 +3251,169 @@ emit_derivative(struct ntv_context *ctx, nir_intrinsic_instr *intr)
store_def(ctx, intr->def.index, result, nir_type_float);
}
/* Emit a SPIR-V OpGroupNonUniform* arithmetic/bitwise/logical instruction for
 * the NIR reduce, inclusive_scan and exclusive_scan intrinsics.
 *
 * The opcode is selected from the intrinsic's reduction op and the group
 * operation from the intrinsic itself. A reduce with a non-zero cluster size
 * is emitted as a clustered reduce and carries the cluster size as an extra
 * 32-bit uint constant operand.
 *
 * ctx:  translation context; instructions and capabilities go to ctx->builder
 * intr: the intrinsic to translate; src[0] is the value being reduced and the
 *       result is stored to intr->def
 */
static void
emit_subgroup(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
   SpvOp op;
   switch (nir_intrinsic_reduction_op(intr)) {
#define SUBGROUP_CASE(nir, spirv) \
   case nir_op_##nir: \
      op = SpvOpGroupNonUniform##spirv; \
      break
   SUBGROUP_CASE(iadd, IAdd);
   SUBGROUP_CASE(fadd, FAdd);
   SUBGROUP_CASE(imul, IMul);
   SUBGROUP_CASE(fmul, FMul);
   SUBGROUP_CASE(imin, SMin);
   SUBGROUP_CASE(umin, UMin);
   SUBGROUP_CASE(fmin, FMin);
   SUBGROUP_CASE(imax, SMax);
   SUBGROUP_CASE(umax, UMax);
   SUBGROUP_CASE(fmax, FMax);
#undef SUBGROUP_CASE
   /* 1-bit sources use the Logical* opcodes, everything else the Bitwise* ones */
#define SUBGROUP_CASE_LOGICAL(nir, spirv) \
   case nir_op_##nir: \
      op = intr->src[0].ssa->bit_size != 1 ? SpvOpGroupNonUniformBitwise##spirv : SpvOpGroupNonUniformLogical##spirv; \
      break
   SUBGROUP_CASE_LOGICAL(iand, And);
   SUBGROUP_CASE_LOGICAL(ior, Or);
   SUBGROUP_CASE_LOGICAL(ixor, Xor);
#undef SUBGROUP_CASE_LOGICAL
   default:
      /* the reduction op is a nir_op, so its name lives in nir_op_infos[],
       * not in the intrinsic info table
       */
      fprintf(stderr, "emit_subgroup: reduction op not implemented (%s)\n",
              nir_op_infos[nir_intrinsic_reduction_op(intr)].name);
      unreachable("unhandled intrinsic");
   }

   SpvGroupOperation groupop;
   unsigned cluster_size = 0;
   switch (intr->intrinsic) {
   case nir_intrinsic_reduce:
      /* only reduce reads a cluster size; zero selects a plain reduce */
      cluster_size = nir_intrinsic_cluster_size(intr);
      groupop = cluster_size ? SpvGroupOperationClusteredReduce : SpvGroupOperationReduce;
      break;
   case nir_intrinsic_inclusive_scan:
      groupop = SpvGroupOperationInclusiveScan;
      break;
   case nir_intrinsic_exclusive_scan:
      groupop = SpvGroupOperationExclusiveScan;
      break;
   default:
      fprintf(stderr, "emit_subgroup: not implemented (%s)\n",
              nir_intrinsic_infos[intr->intrinsic].name);
      unreachable("unhandled intrinsic");
   }

   spirv_builder_emit_cap(&ctx->builder, cluster_size ? SpvCapabilityGroupNonUniformClustered : SpvCapabilityGroupNonUniformArithmetic);

   nir_alu_type atype;
   SpvId src0 = get_src(ctx, &intr->src[0], &atype);
   switch (op) {
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformFMax:
      /* float opcodes: force the operand (and the result) to float type */
      atype = nir_type_float;
      src0 = emit_bitcast(ctx, get_def_type(ctx, intr->src[0].ssa, atype), src0);
      break;
   default:
      break;
   }

   SpvId type = get_def_type(ctx, intr->src[0].ssa, atype);
   SpvId result = 0;
   if (cluster_size)
      result = spirv_builder_emit_triop_subgroup(&ctx->builder, op, type, groupop, src0, spirv_builder_const_uint(&ctx->builder, 32, cluster_size));
   else
      result = spirv_builder_emit_binop_subgroup(&ctx->builder, op, type, groupop, src0);
   store_def(ctx, intr->def.index, result, atype);
}
/* Translate the NIR quad intrinsics (quad_broadcast and the three quad_swap
 * variants) into OpGroupNonUniformQuadBroadcast / OpGroupNonUniformQuadSwap.
 *
 * src[0] is the value being exchanged. The second SPIR-V operand is either
 * the broadcast index taken from src[1] (bitcast to uint when it is not
 * already uint) or a 32-bit uint constant selecting the swap direction.
 * The result keeps the type reported for src[0] and is stored to intr->def.
 */
static void
emit_subgroup_quad(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
   /* values passed as the direction operand of OpGroupNonUniformQuadSwap */
   enum {
      QUAD_SWAP_HORIZONTAL,
      QUAD_SWAP_VERTICAL,
      QUAD_SWAP_DIAGONAL,
   };

   nir_alu_type value_type;
   SpvId value = get_src(ctx, &intr->src[0], &value_type);

   SpvOp op;
   SpvId selector = 0;
   switch (intr->intrinsic) {
   case nir_intrinsic_quad_broadcast: {
      nir_alu_type index_type;
      op = SpvOpGroupNonUniformQuadBroadcast;
      selector = get_src(ctx, &intr->src[1], &index_type);
      if (index_type != nir_type_uint) {
         SpvId uint_type = get_def_type(ctx, intr->src[1].ssa, nir_type_uint);
         selector = emit_bitcast(ctx, uint_type, selector);
      }
      break;
   }
   case nir_intrinsic_quad_swap_horizontal:
      op = SpvOpGroupNonUniformQuadSwap;
      selector = spirv_builder_const_uint(&ctx->builder, 32, QUAD_SWAP_HORIZONTAL);
      break;
   case nir_intrinsic_quad_swap_vertical:
      op = SpvOpGroupNonUniformQuadSwap;
      selector = spirv_builder_const_uint(&ctx->builder, 32, QUAD_SWAP_VERTICAL);
      break;
   case nir_intrinsic_quad_swap_diagonal:
      op = SpvOpGroupNonUniformQuadSwap;
      selector = spirv_builder_const_uint(&ctx->builder, 32, QUAD_SWAP_DIAGONAL);
      break;
   default:
      fprintf(stderr, "emit_subgroup_quad: not implemented (%s)\n",
              nir_intrinsic_infos[intr->intrinsic].name);
      unreachable("unhandled intrinsic");
   }

   spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformQuad);

   SpvId result_type = get_def_type(ctx, intr->src[0].ssa, value_type);
   SpvId result = spirv_builder_emit_binop_subgroup(&ctx->builder, op, result_type,
                                                    value, selector);
   store_def(ctx, intr->def.index, result, value_type);
}
/* Translate the NIR shuffle intrinsics (shuffle, shuffle_xor, shuffle_up,
 * shuffle_down) into the matching OpGroupNonUniformShuffle* instruction.
 *
 * src[0] is the value being shuffled; src[1] is the invocation id, xor mask
 * or delta. The required capability (Shuffle or ShuffleRelative) is emitted
 * alongside the opcode selection. The result keeps the type reported for
 * src[0] and is stored to intr->def.
 */
static void
emit_shuffle(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
   SpvOp op;
   switch (intr->intrinsic) {
   case nir_intrinsic_shuffle:
      op = SpvOpGroupNonUniformShuffle;
      spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformShuffle);
      break;
   case nir_intrinsic_shuffle_xor:
      op = SpvOpGroupNonUniformShuffleXor;
      spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformShuffle);
      break;
   case nir_intrinsic_shuffle_up:
      op = SpvOpGroupNonUniformShuffleUp;
      spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformShuffleRelative);
      break;
   case nir_intrinsic_shuffle_down:
      op = SpvOpGroupNonUniformShuffleDown;
      spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformShuffleRelative);
      break;
   default:
      fprintf(stderr, "emit_shuffle: not implemented (%s)\n",
              nir_intrinsic_infos[intr->intrinsic].name);
      unreachable("unhandled intrinsic");
   }

   nir_alu_type atype, itype;
   SpvId src0 = get_src(ctx, &intr->src[0], &atype);
   SpvId src1 = get_src(ctx, &intr->src[1], &itype);
   /* SPIR-V requires the Id/Mask/Delta operand to be an unsigned integer
    * scalar, so reinterpret it as uint whenever the source was recorded with
    * a different type (same handling as the quad_broadcast index)
    */
   if (itype != nir_type_uint)
      src1 = emit_bitcast(ctx, get_def_type(ctx, intr->src[1].ssa, nir_type_uint), src1);

   SpvId result = spirv_builder_emit_binop_subgroup(&ctx->builder, op, get_def_type(ctx, intr->src[0].ssa, atype), src0, src1);
   store_def(ctx, intr->def.index, result, atype);
}
/* Translate nir_intrinsic_elect: declare the GroupNonUniform capability, emit
 * OpGroupNonUniformElect with SpvScopeSubgroup as its operand, and store the
 * boolean result to intr->def.
 */
static void
emit_elect(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
   struct spirv_builder *b = &ctx->builder;

   spirv_builder_emit_cap(b, SpvCapabilityGroupNonUniform);

   SpvId bool_type = spirv_builder_type_bool(b);
   SpvId elected = spirv_builder_emit_unop_const(b, SpvOpGroupNonUniformElect,
                                                 bool_type, SpvScopeSubgroup);
   store_def(ctx, intr->def.index, elected, nir_type_bool);
}
static void
emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
@ -3463,6 +3627,7 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
LOAD_SHADER_BALLOT(subgroup_le_mask, SubgroupLeMask);
LOAD_SHADER_BALLOT(subgroup_lt_mask, SubgroupLtMask);
LOAD_SHADER_BALLOT(subgroup_size, SubgroupSize);
LOAD_SHADER_BALLOT(num_subgroups, NumSubgroups);
case nir_intrinsic_ballot:
emit_ballot(ctx, intr);
@ -3525,6 +3690,30 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
emit_derivative(ctx, intr);
break;
case nir_intrinsic_reduce:
case nir_intrinsic_inclusive_scan:
case nir_intrinsic_exclusive_scan:
emit_subgroup(ctx, intr);
break;
case nir_intrinsic_quad_broadcast:
case nir_intrinsic_quad_swap_horizontal:
case nir_intrinsic_quad_swap_vertical:
case nir_intrinsic_quad_swap_diagonal:
emit_subgroup_quad(ctx, intr);
break;
case nir_intrinsic_shuffle:
case nir_intrinsic_shuffle_xor:
case nir_intrinsic_shuffle_up:
case nir_intrinsic_shuffle_down:
emit_shuffle(ctx, intr);
break;
case nir_intrinsic_elect:
emit_elect(ctx, intr);
break;
default:
fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
nir_intrinsic_infos[intr->intrinsic].name);
@ -4887,6 +5076,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, const s
"main", ctx.entry_ifaces,
ctx.num_entry_ifaces);
if (ctx.num_subgroups_var)
spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGroupNonUniform);
size_t num_words = spirv_builder_get_num_words(&ctx.builder);
ret = ralloc(NULL, struct spirv_shader);

View file

@ -595,6 +595,23 @@ spirv_builder_emit_binop(struct spirv_builder *b, SpvOp op, SpvId result_type,
return result;
}
/* Emit a six-word instruction of the form
 *
 *    <op> <result_type> <result> <subgroup scope> <operand0> <operand1>
 *
 * where the scope word is the id of a 32-bit uint constant holding
 * SpvScopeSubgroup. Returns the freshly allocated result id.
 *
 * SpvOpSpecConstantOp goes to the types/constants section; every other
 * opcode goes to the instruction stream.
 */
SpvId
spirv_builder_emit_binop_subgroup(struct spirv_builder *b, SpvOp op, SpvId result_type,
                                  SpvId operand0, SpvId operand1)
{
   /* opcode word + result type + result id + scope + two operands */
   enum { num_words = 6 };

   struct spirv_buffer *out;
   if (op == SpvOpSpecConstantOp)
      out = &b->types_const_defs;
   else
      out = &b->instructions;

   SpvId id = spirv_builder_new_id(b);

   spirv_buffer_prepare(out, b->mem_ctx, num_words);
   spirv_buffer_emit_word(out, op | (num_words << 16));
   spirv_buffer_emit_word(out, result_type);
   spirv_buffer_emit_word(out, id);

   /* the scope constant is looked up only after the header words are written */
   SpvId scope = spirv_builder_const_uint(b, 32, SpvScopeSubgroup);
   spirv_buffer_emit_word(out, scope);

   spirv_buffer_emit_word(out, operand0);
   spirv_buffer_emit_word(out, operand1);
   return id;
}
SpvId
spirv_builder_emit_triop(struct spirv_builder *b, SpvOp op, SpvId result_type,
SpvId operand0, SpvId operand1, SpvId operand2)
@ -612,6 +629,24 @@ spirv_builder_emit_triop(struct spirv_builder *b, SpvOp op, SpvId result_type,
return result;
}
/* Emit a seven-word instruction of the form
 *
 *    <op> <result_type> <result> <subgroup scope> <operand0> <operand1> <operand2>
 *
 * where the scope word is the id of a 32-bit uint constant holding
 * SpvScopeSubgroup. Returns the freshly allocated result id.
 *
 * SpvOpSpecConstantOp goes to the types/constants section; every other
 * opcode goes to the instruction stream.
 */
SpvId
spirv_builder_emit_triop_subgroup(struct spirv_builder *b, SpvOp op, SpvId result_type,
                                  SpvId operand0, SpvId operand1, SpvId operand2)
{
   /* opcode word + result type + result id + scope + three operands */
   enum { num_words = 7 };

   struct spirv_buffer *out;
   if (op == SpvOpSpecConstantOp)
      out = &b->types_const_defs;
   else
      out = &b->instructions;

   SpvId id = spirv_builder_new_id(b);

   spirv_buffer_prepare(out, b->mem_ctx, num_words);
   spirv_buffer_emit_word(out, op | (num_words << 16));
   spirv_buffer_emit_word(out, result_type);
   spirv_buffer_emit_word(out, id);

   /* the scope constant is looked up only after the header words are written */
   SpvId scope = spirv_builder_const_uint(b, 32, SpvScopeSubgroup);
   spirv_buffer_emit_word(out, scope);

   spirv_buffer_emit_word(out, operand0);
   spirv_buffer_emit_word(out, operand1);
   spirv_buffer_emit_word(out, operand2);
   return id;
}
SpvId
spirv_builder_emit_quadop(struct spirv_builder *b, SpvOp op, SpvId result_type,
SpvId operand0, SpvId operand1, SpvId operand2, SpvId operand3)

View file

@ -235,10 +235,18 @@ SpvId
spirv_builder_emit_binop(struct spirv_builder *b, SpvOp op, SpvId result_type,
SpvId operand0, SpvId operand1);
/* Emit `op` with a subgroup-scope operand (a 32-bit uint constant holding
 * SpvScopeSubgroup) inserted before operand0 and operand1.
 * Returns the id of the instruction's result.
 */
SpvId
spirv_builder_emit_binop_subgroup(struct spirv_builder *b, SpvOp op, SpvId result_type,
SpvId operand0, SpvId operand1);
SpvId
spirv_builder_emit_triop(struct spirv_builder *b, SpvOp op, SpvId result_type,
SpvId operand0, SpvId operand1, SpvId operand2);
/* Emit `op` with a subgroup-scope operand (a 32-bit uint constant holding
 * SpvScopeSubgroup) inserted before operand0, operand1 and operand2.
 * Returns the id of the instruction's result.
 */
SpvId
spirv_builder_emit_triop_subgroup(struct spirv_builder *b, SpvOp op, SpvId result_type,
SpvId operand0, SpvId operand1, SpvId operand2);
SpvId
spirv_builder_emit_quadop(struct spirv_builder *b, SpvOp op, SpvId result_type,
SpvId operand0, SpvId operand1, SpvId operand2, SpvId operand3);