mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-04-25 11:20:49 +02:00
zink: implement ops for KHR_shader_subgroup
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31228>
This commit is contained in:
parent
b4f3136fea
commit
cee77ba6ec
3 changed files with 236 additions and 1 deletions
|
|
@ -109,7 +109,8 @@ struct ntv_context {
|
|||
subgroup_invocation_var,
|
||||
subgroup_le_mask_var,
|
||||
subgroup_lt_mask_var,
|
||||
subgroup_size_var;
|
||||
subgroup_size_var,
|
||||
num_subgroups_var;
|
||||
|
||||
SpvId discard_func;
|
||||
SpvId float_array_type[2];
|
||||
|
|
@ -3250,6 +3251,169 @@ emit_derivative(struct ntv_context *ctx, nir_intrinsic_instr *intr)
|
|||
store_def(ctx, intr->def.index, result, nir_type_float);
|
||||
}
|
||||
|
||||
static void
|
||||
emit_subgroup(struct ntv_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
SpvOp op;
|
||||
|
||||
switch (nir_intrinsic_reduction_op(intr)) {
|
||||
#define SUBGROUP_CASE(nir, spirv) \
|
||||
case nir_op_##nir: \
|
||||
op = SpvOpGroupNonUniform##spirv; \
|
||||
break
|
||||
SUBGROUP_CASE(iadd, IAdd);
|
||||
SUBGROUP_CASE(fadd, FAdd);
|
||||
SUBGROUP_CASE(imul, IMul);
|
||||
SUBGROUP_CASE(fmul, FMul);
|
||||
SUBGROUP_CASE(imin, SMin);
|
||||
SUBGROUP_CASE(umin, UMin);
|
||||
SUBGROUP_CASE(fmin, FMin);
|
||||
SUBGROUP_CASE(imax, SMax);
|
||||
SUBGROUP_CASE(umax, UMax);
|
||||
SUBGROUP_CASE(fmax, FMax);
|
||||
#undef SUBGROUP_CASE
|
||||
|
||||
#define SUBGROUP_CASE_LOGICAL(nir, spirv) \
|
||||
case nir_op_##nir: \
|
||||
op = intr->src[0].ssa->bit_size != 1 ? SpvOpGroupNonUniformBitwise##spirv : SpvOpGroupNonUniformLogical##spirv; \
|
||||
break
|
||||
SUBGROUP_CASE_LOGICAL(iand, And);
|
||||
SUBGROUP_CASE_LOGICAL(ior, Or);
|
||||
SUBGROUP_CASE_LOGICAL(ixor, Xor);
|
||||
#undef SUBGROUP_CASE_LOGICAL
|
||||
default:
|
||||
fprintf(stderr, "emit_subgroup: reduction op not implemented (%s)\n",
|
||||
nir_intrinsic_infos[nir_intrinsic_reduction_op(intr)].name);
|
||||
unreachable("unhandled intrinsic");
|
||||
}
|
||||
|
||||
SpvGroupOperation groupop;
|
||||
unsigned cluster_size = 0;
|
||||
switch (intr->intrinsic) {
|
||||
case nir_intrinsic_reduce:
|
||||
cluster_size = nir_intrinsic_cluster_size(intr);
|
||||
groupop = cluster_size ? SpvGroupOperationClusteredReduce : SpvGroupOperationReduce;
|
||||
break;
|
||||
case nir_intrinsic_inclusive_scan:
|
||||
groupop = SpvGroupOperationInclusiveScan;
|
||||
break;
|
||||
case nir_intrinsic_exclusive_scan:
|
||||
groupop = SpvGroupOperationExclusiveScan;
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "emit_subgroup: not implemented (%s)\n",
|
||||
nir_intrinsic_infos[intr->intrinsic].name);
|
||||
unreachable("unhandled intrinsic");
|
||||
}
|
||||
spirv_builder_emit_cap(&ctx->builder, cluster_size ? SpvCapabilityGroupNonUniformClustered : SpvCapabilityGroupNonUniformArithmetic);
|
||||
|
||||
nir_alu_type atype;
|
||||
SpvId src0 = get_src(ctx, &intr->src[0], &atype);
|
||||
switch (op) {
|
||||
case SpvOpGroupNonUniformFAdd:
|
||||
case SpvOpGroupNonUniformFMul:
|
||||
case SpvOpGroupNonUniformFMin:
|
||||
case SpvOpGroupNonUniformFMax:
|
||||
atype = nir_type_float;
|
||||
src0 = emit_bitcast(ctx, get_def_type(ctx, intr->src[0].ssa, atype), src0);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
SpvId type = get_def_type(ctx, intr->src[0].ssa, atype);
|
||||
SpvId result = 0;
|
||||
if (cluster_size)
|
||||
result = spirv_builder_emit_triop_subgroup(&ctx->builder, op, type, groupop, src0, spirv_builder_const_uint(&ctx->builder, 32, cluster_size));
|
||||
else
|
||||
result = spirv_builder_emit_binop_subgroup(&ctx->builder, op, type, groupop, src0);
|
||||
store_def(ctx, intr->def.index, result, atype);
|
||||
}
|
||||
|
||||
static void
|
||||
emit_subgroup_quad(struct ntv_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
SpvOp op;
|
||||
nir_alu_type atype, itype;
|
||||
SpvId src0 = get_src(ctx, &intr->src[0], &atype);
|
||||
SpvId src1 = 0;
|
||||
enum {
|
||||
QUAD_SWAP_HORIZONTAL,
|
||||
QUAD_SWAP_VERTICAL,
|
||||
QUAD_SWAP_DIAGONAL,
|
||||
};
|
||||
|
||||
switch (intr->intrinsic) {
|
||||
case nir_intrinsic_quad_broadcast:
|
||||
op = SpvOpGroupNonUniformQuadBroadcast;
|
||||
src1 = get_src(ctx, &intr->src[1], &itype);
|
||||
if (itype != nir_type_uint)
|
||||
src1 = emit_bitcast(ctx, get_def_type(ctx, intr->src[1].ssa, nir_type_uint), src1);
|
||||
break;
|
||||
case nir_intrinsic_quad_swap_horizontal:
|
||||
op = SpvOpGroupNonUniformQuadSwap;
|
||||
src1 = spirv_builder_const_uint(&ctx->builder, 32, QUAD_SWAP_HORIZONTAL);
|
||||
break;
|
||||
case nir_intrinsic_quad_swap_vertical:
|
||||
op = SpvOpGroupNonUniformQuadSwap;
|
||||
src1 = spirv_builder_const_uint(&ctx->builder, 32, QUAD_SWAP_VERTICAL);
|
||||
break;
|
||||
case nir_intrinsic_quad_swap_diagonal:
|
||||
op = SpvOpGroupNonUniformQuadSwap;
|
||||
src1 = spirv_builder_const_uint(&ctx->builder, 32, QUAD_SWAP_DIAGONAL);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "emit_subgroup_quad: not implemented (%s)\n",
|
||||
nir_intrinsic_infos[intr->intrinsic].name);
|
||||
unreachable("unhandled intrinsic");
|
||||
}
|
||||
spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformQuad);
|
||||
|
||||
SpvId result = spirv_builder_emit_binop_subgroup(&ctx->builder, op, get_def_type(ctx, intr->src[0].ssa, atype), src0, src1);
|
||||
store_def(ctx, intr->def.index, result, atype);
|
||||
}
|
||||
|
||||
static void
|
||||
emit_shuffle(struct ntv_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
SpvOp op;
|
||||
|
||||
switch (intr->intrinsic) {
|
||||
case nir_intrinsic_shuffle:
|
||||
op = SpvOpGroupNonUniformShuffle;
|
||||
spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformShuffle);
|
||||
break;
|
||||
case nir_intrinsic_shuffle_xor:
|
||||
op = SpvOpGroupNonUniformShuffleXor;
|
||||
spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformShuffle);
|
||||
break;
|
||||
case nir_intrinsic_shuffle_up:
|
||||
op = SpvOpGroupNonUniformShuffleUp;
|
||||
spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformShuffleRelative);
|
||||
break;
|
||||
case nir_intrinsic_shuffle_down:
|
||||
op = SpvOpGroupNonUniformShuffleDown;
|
||||
spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformShuffleRelative);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "emit_shuffle: not implemented (%s)\n",
|
||||
nir_intrinsic_infos[intr->intrinsic].name);
|
||||
unreachable("unhandled intrinsic");
|
||||
}
|
||||
nir_alu_type atype, unused;
|
||||
SpvId src0 = get_src(ctx, &intr->src[0], &atype);
|
||||
SpvId src1 = get_src(ctx, &intr->src[1], &unused);
|
||||
|
||||
SpvId result = spirv_builder_emit_binop_subgroup(&ctx->builder, op, get_def_type(ctx, intr->src[0].ssa, atype), src0, src1);
|
||||
store_def(ctx, intr->def.index, result, atype);
|
||||
}
|
||||
|
||||
static void
|
||||
emit_elect(struct ntv_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniform);
|
||||
SpvId result = spirv_builder_emit_unop_const(&ctx->builder, SpvOpGroupNonUniformElect, spirv_builder_type_bool(&ctx->builder), SpvScopeSubgroup);
|
||||
store_def(ctx, intr->def.index, result, nir_type_bool);
|
||||
}
|
||||
|
||||
static void
|
||||
emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
|
|
@ -3463,6 +3627,7 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
|
|||
LOAD_SHADER_BALLOT(subgroup_le_mask, SubgroupLeMask);
|
||||
LOAD_SHADER_BALLOT(subgroup_lt_mask, SubgroupLtMask);
|
||||
LOAD_SHADER_BALLOT(subgroup_size, SubgroupSize);
|
||||
LOAD_SHADER_BALLOT(num_subgroups, NumSubgroups);
|
||||
|
||||
case nir_intrinsic_ballot:
|
||||
emit_ballot(ctx, intr);
|
||||
|
|
@ -3525,6 +3690,30 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
|
|||
emit_derivative(ctx, intr);
|
||||
break;
|
||||
|
||||
case nir_intrinsic_reduce:
|
||||
case nir_intrinsic_inclusive_scan:
|
||||
case nir_intrinsic_exclusive_scan:
|
||||
emit_subgroup(ctx, intr);
|
||||
break;
|
||||
|
||||
case nir_intrinsic_quad_broadcast:
|
||||
case nir_intrinsic_quad_swap_horizontal:
|
||||
case nir_intrinsic_quad_swap_vertical:
|
||||
case nir_intrinsic_quad_swap_diagonal:
|
||||
emit_subgroup_quad(ctx, intr);
|
||||
break;
|
||||
|
||||
case nir_intrinsic_shuffle:
|
||||
case nir_intrinsic_shuffle_xor:
|
||||
case nir_intrinsic_shuffle_up:
|
||||
case nir_intrinsic_shuffle_down:
|
||||
emit_shuffle(ctx, intr);
|
||||
break;
|
||||
|
||||
case nir_intrinsic_elect:
|
||||
emit_elect(ctx, intr);
|
||||
break;
|
||||
|
||||
default:
|
||||
fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
|
||||
nir_intrinsic_infos[intr->intrinsic].name);
|
||||
|
|
@ -4887,6 +5076,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, const s
|
|||
"main", ctx.entry_ifaces,
|
||||
ctx.num_entry_ifaces);
|
||||
|
||||
if (ctx.num_subgroups_var)
|
||||
spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGroupNonUniform);
|
||||
|
||||
size_t num_words = spirv_builder_get_num_words(&ctx.builder);
|
||||
|
||||
ret = ralloc(NULL, struct spirv_shader);
|
||||
|
|
|
|||
|
|
@ -595,6 +595,23 @@ spirv_builder_emit_binop(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
|||
return result;
|
||||
}
|
||||
|
||||
SpvId
|
||||
spirv_builder_emit_binop_subgroup(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1)
|
||||
{
|
||||
struct spirv_buffer *buf = op == SpvOpSpecConstantOp ? &b->types_const_defs : &b->instructions;
|
||||
|
||||
SpvId result = spirv_builder_new_id(b);
|
||||
spirv_buffer_prepare(buf, b->mem_ctx, 6);
|
||||
spirv_buffer_emit_word(buf, op | (6 << 16));
|
||||
spirv_buffer_emit_word(buf, result_type);
|
||||
spirv_buffer_emit_word(buf, result);
|
||||
spirv_buffer_emit_word(buf, spirv_builder_const_uint(b, 32, SpvScopeSubgroup));
|
||||
spirv_buffer_emit_word(buf, operand0);
|
||||
spirv_buffer_emit_word(buf, operand1);
|
||||
return result;
|
||||
}
|
||||
|
||||
SpvId
|
||||
spirv_builder_emit_triop(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1, SpvId operand2)
|
||||
|
|
@ -612,6 +629,24 @@ spirv_builder_emit_triop(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
|||
return result;
|
||||
}
|
||||
|
||||
SpvId
|
||||
spirv_builder_emit_triop_subgroup(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1, SpvId operand2)
|
||||
{
|
||||
struct spirv_buffer *buf = op == SpvOpSpecConstantOp ? &b->types_const_defs : &b->instructions;
|
||||
|
||||
SpvId result = spirv_builder_new_id(b);
|
||||
spirv_buffer_prepare(buf, b->mem_ctx, 7);
|
||||
spirv_buffer_emit_word(buf, op | (7 << 16));
|
||||
spirv_buffer_emit_word(buf, result_type);
|
||||
spirv_buffer_emit_word(buf, result);
|
||||
spirv_buffer_emit_word(buf, spirv_builder_const_uint(b, 32, SpvScopeSubgroup));
|
||||
spirv_buffer_emit_word(buf, operand0);
|
||||
spirv_buffer_emit_word(buf, operand1);
|
||||
spirv_buffer_emit_word(buf, operand2);
|
||||
return result;
|
||||
}
|
||||
|
||||
SpvId
|
||||
spirv_builder_emit_quadop(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1, SpvId operand2, SpvId operand3)
|
||||
|
|
|
|||
|
|
@ -235,10 +235,18 @@ SpvId
|
|||
spirv_builder_emit_binop(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1);
|
||||
|
||||
/**
 * Emit a subgroup-scoped instruction with two operands: a 32-bit uint
 * constant holding SpvScopeSubgroup is inserted after the result id and
 * before operand0. Returns the new result id.
 */
SpvId
|
||||
spirv_builder_emit_binop_subgroup(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1);
|
||||
|
||||
SpvId
|
||||
spirv_builder_emit_triop(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1, SpvId operand2);
|
||||
|
||||
/**
 * Emit a subgroup-scoped instruction with three operands: a 32-bit uint
 * constant holding SpvScopeSubgroup is inserted after the result id and
 * before operand0. Returns the new result id.
 */
SpvId
|
||||
spirv_builder_emit_triop_subgroup(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1, SpvId operand2);
|
||||
|
||||
SpvId
|
||||
spirv_builder_emit_quadop(struct spirv_builder *b, SpvOp op, SpvId result_type,
|
||||
SpvId operand0, SpvId operand1, SpvId operand2, SpvId operand3);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue