diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 033cebdcbd6..6e71d358181 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -89,6 +89,7 @@ struct ntv_context { SpvId loop_break, loop_cont; SpvId shared_block_var[5]; //8, 16, 32, unused, 64 + SpvId scratch_block_var[5]; //8, 16, 32, unused, 64 SpvId front_face_var, instance_id_var, vertex_id_var, primitive_id_var, invocation_id_var, // geometry @@ -500,6 +501,34 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type) return ret; } +static void +create_scratch_block(struct ntv_context *ctx, unsigned scratch_size, unsigned bit_size) +{ + unsigned idx = bit_size >> 4; + SpvId type = spirv_builder_type_uint(&ctx->builder, bit_size); + unsigned block_size = scratch_size / (bit_size / 8); + assert(block_size); + SpvId array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size)); + spirv_builder_emit_array_stride(&ctx->builder, array, bit_size / 8); + SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, + SpvStorageClassPrivate, + array); + ctx->scratch_block_var[idx] = spirv_builder_emit_var(&ctx->builder, ptr_type, SpvStorageClassPrivate); + if (ctx->spirv_1_4_interfaces) { + assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces)); + ctx->entry_ifaces[ctx->num_entry_ifaces++] = ctx->scratch_block_var[idx]; + } +} + +static SpvId +get_scratch_block(struct ntv_context *ctx, unsigned bit_size) +{ + unsigned idx = bit_size >> 4; + if (!ctx->scratch_block_var[idx]) + create_scratch_block(ctx, ctx->nir->scratch_size, bit_size); + return ctx->scratch_block_var[idx]; +} + static void create_shared_block(struct ntv_context *ctx, unsigned shared_size, unsigned bit_size) { @@ -2519,6 +2548,59 @@ emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) } } +static void +emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) +{ + SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint); + unsigned num_components = nir_dest_num_components(intr->dest); + unsigned bit_size = nir_dest_bit_size(intr->dest); + SpvId uint_type = get_uvec_type(ctx, bit_size, 1); + SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, + SpvStorageClassPrivate, + uint_type); + SpvId offset = get_src(ctx, &intr->src[0]); + SpvId constituents[NIR_MAX_VEC_COMPONENTS]; + SpvId scratch_block = get_scratch_block(ctx, bit_size); + /* need to convert array -> vec */ + for (unsigned i = 0; i < num_components; i++) { + SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, + scratch_block, &offset, 1); + constituents[i] = spirv_builder_emit_load(&ctx->builder, uint_type, member); + offset = emit_binop(ctx, SpvOpIAdd, spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, 1)); + } + SpvId result; + if (num_components > 1) + result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type, constituents, num_components); + else + result = bitcast_to_uvec(ctx, constituents[0], bit_size, num_components); + store_dest(ctx, &intr->dest, result, nir_type_uint); +} + +static void +emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) +{ + SpvId src = get_src(ctx, &intr->src[0]); + + unsigned wrmask = nir_intrinsic_write_mask(intr); + unsigned bit_size = nir_src_bit_size(intr->src[0]); + SpvId uint_type = get_uvec_type(ctx, bit_size, 1); + SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, + SpvStorageClassPrivate, + uint_type); + SpvId offset = get_src(ctx, &intr->src[1]); + SpvId scratch_block = get_scratch_block(ctx, bit_size); + /* this is a partial write, so we have to loop and do a per-component write */ + u_foreach_bit(i, wrmask) { + SpvId scratch_offset = emit_binop(ctx, SpvOpIAdd, spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, i)); + SpvId val = src; + if (nir_src_num_components(intr->src[0]) != 1) + val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, src, &i, 1); + SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, + scratch_block, &scratch_offset, 1); + spirv_builder_emit_store(&ctx->builder, member, val); + } +} + static void emit_load_push_const(struct ntv_context *ctx, nir_intrinsic_instr *intr) { @@ -2572,6 +2654,33 @@ emit_load_push_const(struct ntv_context *ctx, nir_intrinsic_instr *intr) store_dest(ctx, &intr->dest, result, nir_type_uint); } +static void +emit_load_global(struct ntv_context *ctx, nir_intrinsic_instr *intr) +{ + spirv_builder_emit_cap(&ctx->builder, SpvCapabilityPhysicalStorageBufferAddresses); + SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint); + SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, + SpvStorageClassPhysicalStorageBuffer, + dest_type); + SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[0])); + SpvId result = spirv_builder_emit_load(&ctx->builder, dest_type, ptr); + store_dest(ctx, &intr->dest, result, nir_type_uint); +} + +static void +emit_store_global(struct ntv_context *ctx, nir_intrinsic_instr *intr) +{ + spirv_builder_emit_cap(&ctx->builder, SpvCapabilityPhysicalStorageBufferAddresses); + unsigned bit_size = nir_src_bit_size(intr->src[0]); + SpvId dest_type = get_uvec_type(ctx, bit_size, 1); + SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, + SpvStorageClassPhysicalStorageBuffer, + dest_type); + SpvId param = get_src(ctx, &intr->src[0]); + SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[1])); + spirv_builder_emit_store(&ctx->builder, ptr, param); +} + static SpvId create_builtin_var(struct ntv_context *ctx, SpvId var_type, SpvStorageClass storage_class, @@ -3093,6 +3202,14 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) emit_load_push_const(ctx, intr); break; + case nir_intrinsic_load_global: + emit_load_global(ctx, intr); + break; + + case nir_intrinsic_store_global: + emit_store_global(ctx, intr); + break; + case nir_intrinsic_load_front_face: emit_load_front_face(ctx, intr); break; @@ -3361,6 +3478,14 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) emit_store_shared(ctx, intr); break; + case nir_intrinsic_load_scratch: + emit_load_scratch(ctx, intr); + break; + + case nir_intrinsic_store_scratch: + emit_store_scratch(ctx, intr); + break; + case nir_intrinsic_shader_clock: emit_shader_clock(ctx, intr); break; @@ -4297,7 +4422,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ if (s->info.cs.ptr_size == 32) model = SpvAddressingModelPhysical32; else if (s->info.cs.ptr_size == 64) - model = SpvAddressingModelPhysical64; + model = SpvAddressingModelPhysicalStorageBuffer64; else model = SpvAddressingModelLogical; spirv_builder_emit_mem_model(&ctx.builder, model, diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index 5d5582e65a8..5761bd9468f 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -62,20 +62,6 @@ create_vs_pushconst(nir_shader *nir) vs_pushconst->data.location = INT_MAX; //doesn't really matter } -static void -create_cs_pushconst(nir_shader *nir) -{ - nir_variable *cs_pushconst; - /* create compatible layout for the ntv push constant loader */ - struct glsl_struct_field *fields = rzalloc_size(nir, 1 * sizeof(struct glsl_struct_field)); - fields[0].type = glsl_array_type(glsl_uint_type(), 1, 0); - fields[0].name = ralloc_asprintf(nir, "work_dim"); - fields[0].offset = 0; - cs_pushconst = nir_variable_create(nir, nir_var_mem_push_const, - glsl_struct_type(fields, 1, "struct", false), "cs_pushconst"); - cs_pushconst->data.location = INT_MAX; //doesn't really matter -} - static bool reads_work_dim(nir_shader *shader) { @@ -3155,8 +3141,6 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir, else if (nir->info.stage == MESA_SHADER_TESS_CTRL || nir->info.stage == MESA_SHADER_TESS_EVAL) NIR_PASS_V(nir, nir_lower_io_arrays_to_elements_no_indirects, false); - else if (nir->info.stage == MESA_SHADER_KERNEL) - create_cs_pushconst(nir); if (nir->info.stage < MESA_SHADER_FRAGMENT) have_psiz = check_psiz(nir);