diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 17722884818..246a552a6e5 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -46,7 +46,7 @@ struct ntv_context { struct spirv_builder builder; nir_shader *nir; - struct hash_table *glsl_types; + struct hash_table *glsl_types[2]; //[implicit_stride] struct hash_table *bo_struct_types; struct hash_table *bo_array_types; @@ -567,7 +567,7 @@ get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type) } static SpvId -get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type) +get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type, bool implicit_stride) { assert(type); if (glsl_type_is_scalar(type)) @@ -590,25 +590,27 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type) */ struct hash_entry *entry = - _mesa_hash_table_search(ctx->glsl_types, type); + _mesa_hash_table_search(ctx->glsl_types[implicit_stride], type); if (entry) return (SpvId)(uintptr_t)entry->data; SpvId ret; if (glsl_type_is_array(type)) { - SpvId element_type = get_glsl_type(ctx, glsl_get_array_element(type)); + SpvId element_type = get_glsl_type(ctx, glsl_get_array_element(type), implicit_stride); if (glsl_type_is_unsized_array(type)) ret = spirv_builder_type_runtime_array(&ctx->builder, element_type); else ret = spirv_builder_type_array(&ctx->builder, element_type, emit_uint_const(ctx, 32, glsl_get_length(type))); - uint32_t stride = glsl_get_explicit_stride(type); - if (!stride && glsl_type_is_scalar(glsl_get_array_element(type))) { - stride = MAX2(glsl_get_bit_size(glsl_get_array_element(type)) / 8, 1); + if (!implicit_stride) { + uint32_t stride = glsl_get_explicit_stride(type); + if (!stride && glsl_type_is_scalar(glsl_get_array_element(type))) { + stride = MAX2(glsl_get_bit_size(glsl_get_array_element(type)) / 8, 1); + } + if (stride) + spirv_builder_emit_array_stride(&ctx->builder, ret, stride); } - if (stride) - spirv_builder_emit_array_stride(&ctx->builder, ret, stride); } else if (glsl_type_is_struct_or_ifc(type)) { const unsigned length = glsl_get_length(type); @@ -623,7 +625,7 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type) } for (unsigned i = 0; i < glsl_get_length(type); i++) - types[i] = get_glsl_type(ctx, glsl_get_struct_field(type, i)); + types[i] = get_glsl_type(ctx, glsl_get_struct_field(type, i), implicit_stride); ret = spirv_builder_type_struct(&ctx->builder, types, glsl_get_length(type)); for (unsigned i = 0; i < glsl_get_length(type); i++) { @@ -634,7 +636,7 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type) } else unreachable("Unhandled GLSL type"); - _mesa_hash_table_insert(ctx->glsl_types, type, (void *)(uintptr_t)ret); + _mesa_hash_table_insert(ctx->glsl_types[implicit_stride], type, (void *)(uintptr_t)ret); return ret; } @@ -742,7 +744,7 @@ get_shared_block(struct ntv_context *ctx, unsigned bit_size) static SpvId input_var_init(struct ntv_context *ctx, struct nir_variable *var) { - SpvId var_type = get_glsl_type(ctx, var->type); + SpvId var_type = get_glsl_type(ctx, var->type, false); SpvStorageClass sc = get_storage_class(var); if (sc == SpvStorageClassPushConstant) spirv_builder_emit_decoration(&ctx->builder, var_type, SpvDecorationBlock); @@ -854,7 +856,7 @@ emit_input(struct ntv_context *ctx, struct nir_variable *var) static void emit_output(struct ntv_context *ctx, struct nir_variable *var) { - SpvId var_type = get_glsl_type(ctx, var->type); + SpvId var_type = get_glsl_type(ctx, var->type, false); /* SampleMask is always an array in spirv */ if (ctx->stage == MESA_SHADER_FRAGMENT && var->data.location == FRAG_RESULT_SAMPLE_MASK) @@ -947,7 +949,7 @@ emit_output(struct ntv_context *ctx, struct nir_variable *var) static void emit_shader_temp(struct ntv_context *ctx, struct nir_variable *var) { - SpvId var_type = get_glsl_type(ctx, var->type); + SpvId var_type = get_glsl_type(ctx, var->type, true); SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPrivate, @@ -966,7 +968,7 @@ emit_shader_temp(struct ntv_context *ctx, struct nir_variable *var) static void emit_temp(struct ntv_context *ctx, struct nir_variable *var) { - SpvId var_type = get_glsl_type(ctx, var->type); + SpvId var_type = get_glsl_type(ctx, var->type, true); SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassFunction, @@ -2246,7 +2248,7 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) glsl_get_sampler_dim(gtype) == GLSL_SAMPLER_DIM_BUF); atype = nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(gtype)); } else { - type = get_glsl_type(ctx, deref->type); + type = get_glsl_type(ctx, deref->type, deref->modes & (nir_var_shader_temp | nir_var_function_temp)); atype = get_nir_alu_type(deref->type); } SpvId result; @@ -2266,8 +2268,8 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId src = get_src(ctx, &intr->src[1], &stype); const struct glsl_type *gtype = nir_src_as_deref(intr->src[0])->type; - SpvId type = get_glsl_type(ctx, gtype); nir_variable *var = nir_intrinsic_get_var(intr, 0); + SpvId type = get_glsl_type(ctx, gtype, var->data.mode & (nir_var_shader_temp | nir_var_function_temp)); unsigned wrmask = nir_intrinsic_write_mask(intr); if (!glsl_type_is_scalar(gtype) && wrmask != BITFIELD_MASK(glsl_type_is_array(gtype) ? glsl_get_aoa_size(gtype) : glsl_get_vector_elements(gtype))) { @@ -2281,7 +2283,7 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) result_type = get_glsl_basetype(ctx, glsl_get_base_type(gtype)); member_type = get_alu_type(ctx, stype, 1, glsl_get_bit_size(gtype)); } else - member_type = result_type = get_glsl_type(ctx, glsl_get_array_element(gtype)); + member_type = result_type = get_glsl_type(ctx, glsl_get_array_element(gtype), var->data.mode & (nir_var_shader_temp | nir_var_function_temp)); SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, get_storage_class(var), result_type); @@ -2739,9 +2741,9 @@ emit_interpolate(struct ntv_context *ctx, nir_intrinsic_instr *intr) assert(glsl_get_vector_elements(gtype) == intr->num_components); assert(ptype == get_nir_alu_type(gtype)); if (intr->intrinsic == nir_intrinsic_interp_deref_at_centroid) - result = emit_builtin_unop(ctx, op, get_glsl_type(ctx, gtype), ptr); + result = emit_builtin_unop(ctx, op, get_glsl_type(ctx, gtype, false), ptr); else - result = emit_builtin_binop(ctx, op, get_glsl_type(ctx, gtype), ptr, src1); + result = emit_builtin_binop(ctx, op, get_glsl_type(ctx, gtype, false), ptr, src1); store_def(ctx, intr->def.index, result, ptype); } @@ -4252,7 +4254,7 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) case nir_var_shader_in: case nir_var_shader_out: base = get_src(ctx, &deref->parent, &atype); - type = get_glsl_type(ctx, deref->type); + type = get_glsl_type(ctx, deref->type, var->data.mode & (nir_var_shader_temp | nir_var_function_temp)); break; case nir_var_uniform: @@ -4297,7 +4299,7 @@ emit_deref_struct(struct ntv_context *ctx, nir_deref_instr *deref) SpvId index = emit_uint_const(ctx, 32, deref->strct.index); SpvId type = (var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo)) ? get_bo_array_type(ctx, var) : - get_glsl_type(ctx, deref->type); + get_glsl_type(ctx, deref->type, var->data.mode & (nir_var_shader_temp | nir_var_function_temp)); SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, storage_class, @@ -4611,10 +4613,11 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, const s ctx.have_spirv16 = spirv_version >= SPIRV_VERSION(1, 6); ctx.bindless_set_idx = sinfo->bindless_set_idx; - ctx.glsl_types = _mesa_pointer_hash_table_create(ctx.mem_ctx); + ctx.glsl_types[0] = _mesa_pointer_hash_table_create(ctx.mem_ctx); + ctx.glsl_types[1] = _mesa_pointer_hash_table_create(ctx.mem_ctx); ctx.bo_array_types = _mesa_pointer_hash_table_create(ctx.mem_ctx); ctx.bo_struct_types = _mesa_pointer_hash_table_create(ctx.mem_ctx); - if (!ctx.glsl_types || !ctx.bo_array_types || !ctx.bo_struct_types || + if (!ctx.glsl_types[0] || !ctx.glsl_types[1] || !ctx.bo_array_types || !ctx.bo_struct_types || !_mesa_hash_table_init(&ctx.image_types, ctx.mem_ctx, _mesa_hash_pointer, _mesa_key_pointer_equal)) goto fail;