zink: use implicit stride in ntv for temp vars

APPARENTLY explicit stride is illegal for temp vars because they should
just be using the element stride implicitly

this makes total sense and is very obvious

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33651>
This commit is contained in:
Mike Blumenkrantz 2025-03-28 13:57:34 -04:00 committed by Marge Bot
parent b4e3535650
commit 0b7611824a

View file

@ -46,7 +46,7 @@ struct ntv_context {
struct spirv_builder builder; struct spirv_builder builder;
nir_shader *nir; nir_shader *nir;
struct hash_table *glsl_types; struct hash_table *glsl_types[2]; //[implicit_stride]
struct hash_table *bo_struct_types; struct hash_table *bo_struct_types;
struct hash_table *bo_array_types; struct hash_table *bo_array_types;
@ -567,7 +567,7 @@ get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
} }
static SpvId static SpvId
get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type) get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type, bool implicit_stride)
{ {
assert(type); assert(type);
if (glsl_type_is_scalar(type)) if (glsl_type_is_scalar(type))
@ -590,25 +590,27 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
*/ */
struct hash_entry *entry = struct hash_entry *entry =
_mesa_hash_table_search(ctx->glsl_types, type); _mesa_hash_table_search(ctx->glsl_types[implicit_stride], type);
if (entry) if (entry)
return (SpvId)(uintptr_t)entry->data; return (SpvId)(uintptr_t)entry->data;
SpvId ret; SpvId ret;
if (glsl_type_is_array(type)) { if (glsl_type_is_array(type)) {
SpvId element_type = get_glsl_type(ctx, glsl_get_array_element(type)); SpvId element_type = get_glsl_type(ctx, glsl_get_array_element(type), implicit_stride);
if (glsl_type_is_unsized_array(type)) if (glsl_type_is_unsized_array(type))
ret = spirv_builder_type_runtime_array(&ctx->builder, element_type); ret = spirv_builder_type_runtime_array(&ctx->builder, element_type);
else else
ret = spirv_builder_type_array(&ctx->builder, ret = spirv_builder_type_array(&ctx->builder,
element_type, element_type,
emit_uint_const(ctx, 32, glsl_get_length(type))); emit_uint_const(ctx, 32, glsl_get_length(type)));
if (!implicit_stride) {
uint32_t stride = glsl_get_explicit_stride(type); uint32_t stride = glsl_get_explicit_stride(type);
if (!stride && glsl_type_is_scalar(glsl_get_array_element(type))) { if (!stride && glsl_type_is_scalar(glsl_get_array_element(type))) {
stride = MAX2(glsl_get_bit_size(glsl_get_array_element(type)) / 8, 1); stride = MAX2(glsl_get_bit_size(glsl_get_array_element(type)) / 8, 1);
} }
if (stride) if (stride)
spirv_builder_emit_array_stride(&ctx->builder, ret, stride); spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
}
} else if (glsl_type_is_struct_or_ifc(type)) { } else if (glsl_type_is_struct_or_ifc(type)) {
const unsigned length = glsl_get_length(type); const unsigned length = glsl_get_length(type);
@ -623,7 +625,7 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
} }
for (unsigned i = 0; i < glsl_get_length(type); i++) for (unsigned i = 0; i < glsl_get_length(type); i++)
types[i] = get_glsl_type(ctx, glsl_get_struct_field(type, i)); types[i] = get_glsl_type(ctx, glsl_get_struct_field(type, i), implicit_stride);
ret = spirv_builder_type_struct(&ctx->builder, types, ret = spirv_builder_type_struct(&ctx->builder, types,
glsl_get_length(type)); glsl_get_length(type));
for (unsigned i = 0; i < glsl_get_length(type); i++) { for (unsigned i = 0; i < glsl_get_length(type); i++) {
@ -634,7 +636,7 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
} else } else
unreachable("Unhandled GLSL type"); unreachable("Unhandled GLSL type");
_mesa_hash_table_insert(ctx->glsl_types, type, (void *)(uintptr_t)ret); _mesa_hash_table_insert(ctx->glsl_types[implicit_stride], type, (void *)(uintptr_t)ret);
return ret; return ret;
} }
@ -742,7 +744,7 @@ get_shared_block(struct ntv_context *ctx, unsigned bit_size)
static SpvId static SpvId
input_var_init(struct ntv_context *ctx, struct nir_variable *var) input_var_init(struct ntv_context *ctx, struct nir_variable *var)
{ {
SpvId var_type = get_glsl_type(ctx, var->type); SpvId var_type = get_glsl_type(ctx, var->type, false);
SpvStorageClass sc = get_storage_class(var); SpvStorageClass sc = get_storage_class(var);
if (sc == SpvStorageClassPushConstant) if (sc == SpvStorageClassPushConstant)
spirv_builder_emit_decoration(&ctx->builder, var_type, SpvDecorationBlock); spirv_builder_emit_decoration(&ctx->builder, var_type, SpvDecorationBlock);
@ -854,7 +856,7 @@ emit_input(struct ntv_context *ctx, struct nir_variable *var)
static void static void
emit_output(struct ntv_context *ctx, struct nir_variable *var) emit_output(struct ntv_context *ctx, struct nir_variable *var)
{ {
SpvId var_type = get_glsl_type(ctx, var->type); SpvId var_type = get_glsl_type(ctx, var->type, false);
/* SampleMask is always an array in spirv */ /* SampleMask is always an array in spirv */
if (ctx->stage == MESA_SHADER_FRAGMENT && var->data.location == FRAG_RESULT_SAMPLE_MASK) if (ctx->stage == MESA_SHADER_FRAGMENT && var->data.location == FRAG_RESULT_SAMPLE_MASK)
@ -947,7 +949,7 @@ emit_output(struct ntv_context *ctx, struct nir_variable *var)
static void static void
emit_shader_temp(struct ntv_context *ctx, struct nir_variable *var) emit_shader_temp(struct ntv_context *ctx, struct nir_variable *var)
{ {
SpvId var_type = get_glsl_type(ctx, var->type); SpvId var_type = get_glsl_type(ctx, var->type, true);
SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
SpvStorageClassPrivate, SpvStorageClassPrivate,
@ -966,7 +968,7 @@ emit_shader_temp(struct ntv_context *ctx, struct nir_variable *var)
static void static void
emit_temp(struct ntv_context *ctx, struct nir_variable *var) emit_temp(struct ntv_context *ctx, struct nir_variable *var)
{ {
SpvId var_type = get_glsl_type(ctx, var->type); SpvId var_type = get_glsl_type(ctx, var->type, true);
SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
SpvStorageClassFunction, SpvStorageClassFunction,
@ -2246,7 +2248,7 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
glsl_get_sampler_dim(gtype) == GLSL_SAMPLER_DIM_BUF); glsl_get_sampler_dim(gtype) == GLSL_SAMPLER_DIM_BUF);
atype = nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(gtype)); atype = nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(gtype));
} else { } else {
type = get_glsl_type(ctx, deref->type); type = get_glsl_type(ctx, deref->type, deref->modes & (nir_var_shader_temp | nir_var_function_temp));
atype = get_nir_alu_type(deref->type); atype = get_nir_alu_type(deref->type);
} }
SpvId result; SpvId result;
@ -2266,8 +2268,8 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
SpvId src = get_src(ctx, &intr->src[1], &stype); SpvId src = get_src(ctx, &intr->src[1], &stype);
const struct glsl_type *gtype = nir_src_as_deref(intr->src[0])->type; const struct glsl_type *gtype = nir_src_as_deref(intr->src[0])->type;
SpvId type = get_glsl_type(ctx, gtype);
nir_variable *var = nir_intrinsic_get_var(intr, 0); nir_variable *var = nir_intrinsic_get_var(intr, 0);
SpvId type = get_glsl_type(ctx, gtype, var->data.mode & (nir_var_shader_temp | nir_var_function_temp));
unsigned wrmask = nir_intrinsic_write_mask(intr); unsigned wrmask = nir_intrinsic_write_mask(intr);
if (!glsl_type_is_scalar(gtype) && if (!glsl_type_is_scalar(gtype) &&
wrmask != BITFIELD_MASK(glsl_type_is_array(gtype) ? glsl_get_aoa_size(gtype) : glsl_get_vector_elements(gtype))) { wrmask != BITFIELD_MASK(glsl_type_is_array(gtype) ? glsl_get_aoa_size(gtype) : glsl_get_vector_elements(gtype))) {
@ -2281,7 +2283,7 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
result_type = get_glsl_basetype(ctx, glsl_get_base_type(gtype)); result_type = get_glsl_basetype(ctx, glsl_get_base_type(gtype));
member_type = get_alu_type(ctx, stype, 1, glsl_get_bit_size(gtype)); member_type = get_alu_type(ctx, stype, 1, glsl_get_bit_size(gtype));
} else } else
member_type = result_type = get_glsl_type(ctx, glsl_get_array_element(gtype)); member_type = result_type = get_glsl_type(ctx, glsl_get_array_element(gtype), var->data.mode & (nir_var_shader_temp | nir_var_function_temp));
SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
get_storage_class(var), get_storage_class(var),
result_type); result_type);
@ -2739,9 +2741,9 @@ emit_interpolate(struct ntv_context *ctx, nir_intrinsic_instr *intr)
assert(glsl_get_vector_elements(gtype) == intr->num_components); assert(glsl_get_vector_elements(gtype) == intr->num_components);
assert(ptype == get_nir_alu_type(gtype)); assert(ptype == get_nir_alu_type(gtype));
if (intr->intrinsic == nir_intrinsic_interp_deref_at_centroid) if (intr->intrinsic == nir_intrinsic_interp_deref_at_centroid)
result = emit_builtin_unop(ctx, op, get_glsl_type(ctx, gtype), ptr); result = emit_builtin_unop(ctx, op, get_glsl_type(ctx, gtype, false), ptr);
else else
result = emit_builtin_binop(ctx, op, get_glsl_type(ctx, gtype), ptr, src1); result = emit_builtin_binop(ctx, op, get_glsl_type(ctx, gtype, false), ptr, src1);
store_def(ctx, intr->def.index, result, ptype); store_def(ctx, intr->def.index, result, ptype);
} }
@ -4252,7 +4254,7 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
case nir_var_shader_in: case nir_var_shader_in:
case nir_var_shader_out: case nir_var_shader_out:
base = get_src(ctx, &deref->parent, &atype); base = get_src(ctx, &deref->parent, &atype);
type = get_glsl_type(ctx, deref->type); type = get_glsl_type(ctx, deref->type, var->data.mode & (nir_var_shader_temp | nir_var_function_temp));
break; break;
case nir_var_uniform: case nir_var_uniform:
@ -4297,7 +4299,7 @@ emit_deref_struct(struct ntv_context *ctx, nir_deref_instr *deref)
SpvId index = emit_uint_const(ctx, 32, deref->strct.index); SpvId index = emit_uint_const(ctx, 32, deref->strct.index);
SpvId type = (var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo)) ? SpvId type = (var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo)) ?
get_bo_array_type(ctx, var) : get_bo_array_type(ctx, var) :
get_glsl_type(ctx, deref->type); get_glsl_type(ctx, deref->type, var->data.mode & (nir_var_shader_temp | nir_var_function_temp));
SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
storage_class, storage_class,
@ -4611,10 +4613,11 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, const s
ctx.have_spirv16 = spirv_version >= SPIRV_VERSION(1, 6); ctx.have_spirv16 = spirv_version >= SPIRV_VERSION(1, 6);
ctx.bindless_set_idx = sinfo->bindless_set_idx; ctx.bindless_set_idx = sinfo->bindless_set_idx;
ctx.glsl_types = _mesa_pointer_hash_table_create(ctx.mem_ctx); ctx.glsl_types[0] = _mesa_pointer_hash_table_create(ctx.mem_ctx);
ctx.glsl_types[1] = _mesa_pointer_hash_table_create(ctx.mem_ctx);
ctx.bo_array_types = _mesa_pointer_hash_table_create(ctx.mem_ctx); ctx.bo_array_types = _mesa_pointer_hash_table_create(ctx.mem_ctx);
ctx.bo_struct_types = _mesa_pointer_hash_table_create(ctx.mem_ctx); ctx.bo_struct_types = _mesa_pointer_hash_table_create(ctx.mem_ctx);
if (!ctx.glsl_types || !ctx.bo_array_types || !ctx.bo_struct_types || if (!ctx.glsl_types[0] || !ctx.glsl_types[1] || !ctx.bo_array_types || !ctx.bo_struct_types ||
!_mesa_hash_table_init(&ctx.image_types, ctx.mem_ctx, _mesa_hash_pointer, _mesa_key_pointer_equal)) !_mesa_hash_table_init(&ctx.image_types, ctx.mem_ctx, _mesa_hash_pointer, _mesa_key_pointer_equal))
goto fail; goto fail;