zink: unify ntv code for loading shared/scratch memory

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37358>
This commit is contained in:
Mike Blumenkrantz 2025-09-05 08:27:36 -04:00
parent d74eff651b
commit ea4d64531d

View file

@ -2320,25 +2320,24 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
}
static void
emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr)
emit_load_special(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId block, SpvStorageClass storage_class)
{
SpvId dest_type = get_def_type(ctx, &intr->def, nir_type_uint);
unsigned num_components = intr->def.num_components;
unsigned bit_size = intr->def.bit_size;
SpvId uint_type = get_uvec_type(ctx, bit_size, 1);
SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
SpvStorageClassWorkgroup,
storage_class,
uint_type);
nir_alu_type atype;
SpvId offset = get_src(ctx, &intr->src[0], &atype);
if (atype == nir_type_float)
offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1);
SpvId constituents[NIR_MAX_VEC_COMPONENTS];
SpvId shared_block = get_shared_block(ctx, bit_size);
/* need to convert array -> vec */
for (unsigned i = 0; i < num_components; i++) {
SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type,
shared_block, &offset, 1);
block, &offset, 1);
constituents[i] = spirv_builder_emit_load(&ctx->builder, uint_type, member);
offset = emit_binop(ctx, SpvOpIAdd, spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, 1));
}
@ -2350,6 +2349,14 @@ emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr)
store_def(ctx, intr->def.index, result, nir_type_uint);
}
static void
emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
unsigned bit_size = intr->def.bit_size;
SpvId shared_block = get_shared_block(ctx, bit_size);
emit_load_special(ctx, intr, shared_block, SpvStorageClassWorkgroup);
}
static void
emit_store_special(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId block, SpvStorageClass storage_class)
{
@ -2391,32 +2398,9 @@ emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr)
static void
emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
SpvId dest_type = get_def_type(ctx, &intr->def, nir_type_uint);
unsigned num_components = intr->def.num_components;
unsigned bit_size = intr->def.bit_size;
SpvId uint_type = get_uvec_type(ctx, bit_size, 1);
SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
SpvStorageClassPrivate,
uint_type);
nir_alu_type atype;
SpvId offset = get_src(ctx, &intr->src[0], &atype);
if (atype != nir_type_uint)
offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1);
SpvId constituents[NIR_MAX_VEC_COMPONENTS];
SpvId scratch_block = get_scratch_block(ctx, bit_size);
/* need to convert array -> vec */
for (unsigned i = 0; i < num_components; i++) {
SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type,
scratch_block, &offset, 1);
constituents[i] = spirv_builder_emit_load(&ctx->builder, uint_type, member);
offset = emit_binop(ctx, SpvOpIAdd, spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, 1));
}
SpvId result;
if (num_components > 1)
result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type, constituents, num_components);
else
result = constituents[0];
store_def(ctx, intr->def.index, result, nir_type_uint);
emit_load_special(ctx, intr, scratch_block, SpvStorageClassPrivate);
}
static void