From d74eff651bc130d9c44cd1fa65e7fefb3fae20c6 Mon Sep 17 00:00:00 2001 From: Mike Blumenkrantz Date: Fri, 5 Sep 2025 08:27:36 -0400 Subject: [PATCH] zink: unify ntv code for storing shared/scratch memory Part-of: --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 48 +++++++------------ 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 0e27f109ec4..81159be1b18 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -2351,7 +2351,7 @@ emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) } static void -emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) +emit_store_special(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId block, SpvStorageClass storage_class) { nir_alu_type atype; SpvId src = get_src(ctx, &intr->src[0], &atype); @@ -2360,27 +2360,34 @@ emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) unsigned bit_size = nir_src_bit_size(intr->src[0]); SpvId uint_type = get_uvec_type(ctx, bit_size, 1); SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, - SpvStorageClassWorkgroup, + storage_class, uint_type); nir_alu_type otype; SpvId offset = get_src(ctx, &intr->src[1], &otype); - if (otype == nir_type_float) - offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); - SpvId shared_block = get_shared_block(ctx, bit_size); + if (otype != nir_type_uint) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[1]), 1); /* this is a partial write, so we have to loop and do a per-component write */ u_foreach_bit(i, wrmask) { - SpvId shared_offset = emit_binop(ctx, SpvOpIAdd, spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, i)); + SpvId mask_offset = emit_binop(ctx, SpvOpIAdd, spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, i)); SpvId val = src; if (nir_src_num_components(intr->src[0]) != 1) - val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, src, &i, 1); + val = spirv_builder_emit_composite_extract(&ctx->builder, get_alu_type(ctx, atype, 1, bit_size), src, &i, 1); if (atype != nir_type_uint) val = emit_bitcast(ctx, get_alu_type(ctx, nir_type_uint, 1, bit_size), val); SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, - shared_block, &shared_offset, 1); + block, &mask_offset, 1); spirv_builder_emit_store(&ctx->builder, member, val); } } +static void +emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) +{ + unsigned bit_size = nir_src_bit_size(intr->src[0]); + SpvId shared_block = get_shared_block(ctx, bit_size); + emit_store_special(ctx, intr, shared_block, SpvStorageClassWorkgroup); +} + static void emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) { @@ -2415,32 +2422,9 @@ emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - nir_alu_type atype; - SpvId src = get_src(ctx, &intr->src[0], &atype); - - unsigned wrmask = nir_intrinsic_write_mask(intr); unsigned bit_size = nir_src_bit_size(intr->src[0]); - SpvId uint_type = get_uvec_type(ctx, bit_size, 1); - SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, - SpvStorageClassPrivate, - uint_type); - nir_alu_type otype; - SpvId offset = get_src(ctx, &intr->src[1], &otype); - if (otype != nir_type_uint) - offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[1]), 1); SpvId scratch_block = get_scratch_block(ctx, bit_size); - /* this is a partial write, so we have to loop and do a per-component write */ - u_foreach_bit(i, wrmask) { - SpvId scratch_offset = emit_binop(ctx, SpvOpIAdd, spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, i)); - SpvId val = src; - if (nir_src_num_components(intr->src[0]) != 1) - val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, src, &i, 1); - if (atype != nir_type_uint) - val = emit_bitcast(ctx, get_alu_type(ctx, nir_type_uint, 1, bit_size), val); - SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, - scratch_block, &scratch_offset, 1); - spirv_builder_emit_store(&ctx->builder, member, val); - } + emit_store_special(ctx, intr, scratch_block, SpvStorageClassPrivate); } static void