mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-05 00:58:05 +02:00
spirv/subgroups: Refactor to use vtn_push_ssa
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5278>
This commit is contained in:
parent
ea246c3950
commit
00af1128a9
1 changed files with 42 additions and 36 deletions
|
|
@ -23,10 +23,9 @@
|
|||
|
||||
#include "vtn_private.h"
|
||||
|
||||
static void
|
||||
static struct vtn_ssa_value *
|
||||
vtn_build_subgroup_instr(struct vtn_builder *b,
|
||||
nir_intrinsic_op nir_op,
|
||||
struct vtn_ssa_value *dst,
|
||||
struct vtn_ssa_value *src0,
|
||||
nir_ssa_def *index,
|
||||
unsigned const_idx0,
|
||||
|
|
@ -39,14 +38,16 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
|
|||
if (index && index->bit_size != 32)
|
||||
index = nir_u2u32(&b->nb, index);
|
||||
|
||||
struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);
|
||||
|
||||
vtn_assert(dst->type == src0->type);
|
||||
if (!glsl_type_is_vector_or_scalar(dst->type)) {
|
||||
for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
|
||||
vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
|
||||
src0->elems[i], index,
|
||||
const_idx0, const_idx1);
|
||||
dst->elems[0] =
|
||||
vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
|
||||
const_idx0, const_idx1);
|
||||
}
|
||||
return;
|
||||
return dst;
|
||||
}
|
||||
|
||||
nir_intrinsic_instr *intrin =
|
||||
|
|
@ -65,33 +66,33 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
|
|||
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||
|
||||
dst->def = &intrin->dest.ssa;
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
void
|
||||
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count)
|
||||
{
|
||||
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
|
||||
|
||||
val->ssa = vtn_create_ssa_value(b, val->type->type);
|
||||
struct vtn_type *dest_type = vtn_get_type(b, w[1]);
|
||||
|
||||
switch (opcode) {
|
||||
case SpvOpGroupNonUniformElect: {
|
||||
vtn_fail_if(val->type->type != glsl_bool_type(),
|
||||
vtn_fail_if(dest_type->type != glsl_bool_type(),
|
||||
"OpGroupNonUniformElect must return a Bool");
|
||||
nir_intrinsic_instr *elect =
|
||||
nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
|
||||
nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
|
||||
val->type->type, NULL);
|
||||
dest_type->type, NULL);
|
||||
nir_builder_instr_insert(&b->nb, &elect->instr);
|
||||
val->ssa->def = &elect->dest.ssa;
|
||||
vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpGroupNonUniformBallot:
|
||||
case SpvOpSubgroupBallotKHR: {
|
||||
bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
|
||||
vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
|
||||
vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
|
||||
"OpGroupNonUniformBallot must return a uvec4");
|
||||
nir_intrinsic_instr *ballot =
|
||||
nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
|
||||
|
|
@ -99,7 +100,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
|
||||
ballot->num_components = 4;
|
||||
nir_builder_instr_insert(&b->nb, &ballot->instr);
|
||||
val->ssa->def = &ballot->dest.ssa;
|
||||
vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -116,10 +117,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
|
||||
|
||||
nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
|
||||
val->type->type, NULL);
|
||||
dest_type->type, NULL);
|
||||
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||
|
||||
val->ssa->def = &intrin->dest.ssa;
|
||||
vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -171,19 +172,20 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
intrin->src[1] = nir_src_for_ssa(src1);
|
||||
|
||||
nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
|
||||
val->type->type, NULL);
|
||||
dest_type->type, NULL);
|
||||
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||
|
||||
val->ssa->def = &intrin->dest.ssa;
|
||||
vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpGroupNonUniformBroadcastFirst:
|
||||
case SpvOpSubgroupFirstInvocationKHR: {
|
||||
bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
|
||||
val->ssa, vtn_ssa_value(b, w[3 + has_scope]),
|
||||
NULL, 0, 0);
|
||||
vtn_push_ssa_value(b, w[2],
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
|
||||
vtn_ssa_value(b, w[3 + has_scope]),
|
||||
NULL, 0, 0));
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -191,9 +193,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpGroupBroadcast:
|
||||
case SpvOpSubgroupReadInvocationKHR: {
|
||||
bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
|
||||
val->ssa, vtn_ssa_value(b, w[3 + has_scope]),
|
||||
vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0);
|
||||
vtn_push_ssa_value(b, w[2],
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
|
||||
vtn_ssa_value(b, w[3 + has_scope]),
|
||||
vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -205,7 +208,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
case SpvOpSubgroupAllKHR:
|
||||
case SpvOpSubgroupAnyKHR:
|
||||
case SpvOpSubgroupAllEqualKHR: {
|
||||
vtn_fail_if(val->type->type != glsl_bool_type(),
|
||||
vtn_fail_if(dest_type->type != glsl_bool_type(),
|
||||
"OpGroupNonUniform(All|Any|AllEqual) must return a bool");
|
||||
nir_intrinsic_op op;
|
||||
switch (opcode) {
|
||||
|
|
@ -262,10 +265,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
intrin->num_components = src0->num_components;
|
||||
intrin->src[0] = nir_src_for_ssa(src0);
|
||||
nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
|
||||
val->type->type, NULL);
|
||||
dest_type->type, NULL);
|
||||
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||
|
||||
val->ssa->def = &intrin->dest.ssa;
|
||||
vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -290,15 +293,17 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
default:
|
||||
unreachable("Invalid opcode");
|
||||
}
|
||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
|
||||
vtn_get_nir_ssa(b, w[5]), 0, 0);
|
||||
vtn_push_ssa_value(b, w[2],
|
||||
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
|
||||
vtn_get_nir_ssa(b, w[5]), 0, 0));
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpGroupNonUniformQuadBroadcast:
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
|
||||
val->ssa, vtn_ssa_value(b, w[4]),
|
||||
vtn_get_nir_ssa(b, w[5]), 0, 0);
|
||||
vtn_push_ssa_value(b, w[2],
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
|
||||
vtn_ssa_value(b, w[4]),
|
||||
vtn_get_nir_ssa(b, w[5]), 0, 0));
|
||||
break;
|
||||
|
||||
case SpvOpGroupNonUniformQuadSwap: {
|
||||
|
|
@ -317,8 +322,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
default:
|
||||
vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
|
||||
}
|
||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
|
||||
NULL, 0, 0);
|
||||
vtn_push_ssa_value(b, w[2],
|
||||
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -439,8 +444,9 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||
unreachable("Invalid group operation");
|
||||
}
|
||||
|
||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
|
||||
NULL, reduction_op, cluster_size);
|
||||
vtn_push_ssa_value(b, w[2],
|
||||
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
|
||||
reduction_op, cluster_size));
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue