spirv/subgroups: Refactor to use vtn_push_ssa

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5278>
This commit is contained in:
Jason Ekstrand 2020-05-29 14:40:12 -05:00
parent ea246c3950
commit 00af1128a9

View file

@ -23,10 +23,9 @@
#include "vtn_private.h"
static void
static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
nir_intrinsic_op nir_op,
struct vtn_ssa_value *dst,
struct vtn_ssa_value *src0,
nir_ssa_def *index,
unsigned const_idx0,
@ -39,14 +38,16 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
if (index && index->bit_size != 32)
index = nir_u2u32(&b->nb, index);
struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);
vtn_assert(dst->type == src0->type);
if (!glsl_type_is_vector_or_scalar(dst->type)) {
for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
src0->elems[i], index,
const_idx0, const_idx1);
dst->elems[0] =
vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
const_idx0, const_idx1);
}
return;
return dst;
}
nir_intrinsic_instr *intrin =
@ -65,33 +66,33 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
nir_builder_instr_insert(&b->nb, &intrin->instr);
dst->def = &intrin->dest.ssa;
return dst;
}
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
val->ssa = vtn_create_ssa_value(b, val->type->type);
struct vtn_type *dest_type = vtn_get_type(b, w[1]);
switch (opcode) {
case SpvOpGroupNonUniformElect: {
vtn_fail_if(val->type->type != glsl_bool_type(),
vtn_fail_if(dest_type->type != glsl_bool_type(),
"OpGroupNonUniformElect must return a Bool");
nir_intrinsic_instr *elect =
nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
val->type->type, NULL);
dest_type->type, NULL);
nir_builder_instr_insert(&b->nb, &elect->instr);
val->ssa->def = &elect->dest.ssa;
vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
break;
}
case SpvOpGroupNonUniformBallot:
case SpvOpSubgroupBallotKHR: {
bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
"OpGroupNonUniformBallot must return a uvec4");
nir_intrinsic_instr *ballot =
nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
@ -99,7 +100,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
ballot->num_components = 4;
nir_builder_instr_insert(&b->nb, &ballot->instr);
val->ssa->def = &ballot->dest.ssa;
vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
break;
}
@ -116,10 +117,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
val->type->type, NULL);
dest_type->type, NULL);
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
break;
}
@ -171,19 +172,20 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
intrin->src[1] = nir_src_for_ssa(src1);
nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
val->type->type, NULL);
dest_type->type, NULL);
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
break;
}
case SpvOpGroupNonUniformBroadcastFirst:
case SpvOpSubgroupFirstInvocationKHR: {
bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
val->ssa, vtn_ssa_value(b, w[3 + has_scope]),
NULL, 0, 0);
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
vtn_ssa_value(b, w[3 + has_scope]),
NULL, 0, 0));
break;
}
@ -191,9 +193,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
case SpvOpGroupBroadcast:
case SpvOpSubgroupReadInvocationKHR: {
bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
val->ssa, vtn_ssa_value(b, w[3 + has_scope]),
vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0);
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
vtn_ssa_value(b, w[3 + has_scope]),
vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
break;
}
@ -205,7 +208,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
case SpvOpSubgroupAllKHR:
case SpvOpSubgroupAnyKHR:
case SpvOpSubgroupAllEqualKHR: {
vtn_fail_if(val->type->type != glsl_bool_type(),
vtn_fail_if(dest_type->type != glsl_bool_type(),
"OpGroupNonUniform(All|Any|AllEqual) must return a bool");
nir_intrinsic_op op;
switch (opcode) {
@ -262,10 +265,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
intrin->num_components = src0->num_components;
intrin->src[0] = nir_src_for_ssa(src0);
nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
val->type->type, NULL);
dest_type->type, NULL);
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
break;
}
@ -290,15 +293,17 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
default:
unreachable("Invalid opcode");
}
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
vtn_get_nir_ssa(b, w[5]), 0, 0);
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
vtn_get_nir_ssa(b, w[5]), 0, 0));
break;
}
case SpvOpGroupNonUniformQuadBroadcast:
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
val->ssa, vtn_ssa_value(b, w[4]),
vtn_get_nir_ssa(b, w[5]), 0, 0);
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
vtn_ssa_value(b, w[4]),
vtn_get_nir_ssa(b, w[5]), 0, 0));
break;
case SpvOpGroupNonUniformQuadSwap: {
@ -317,8 +322,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
default:
vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
}
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
NULL, 0, 0);
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
break;
}
@ -439,8 +444,9 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
unreachable("Invalid group operation");
}
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
NULL, reduction_op, cluster_size);
vtn_push_ssa_value(b, w[2],
vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
reduction_op, cluster_size));
break;
}