diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index cf377eff545..807186ad946 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -38,6 +38,8 @@ struct ntv_context { struct spirv_builder builder; + struct hash_table *glsl_types; + SpvId GLSL_std_450; gl_shader_stage stage; @@ -358,8 +360,17 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type) glsl_get_vector_elements(type)), glsl_get_matrix_columns(type)); + /* Aggregate types aren't cached in spirv_builder, so let's cache + * them here instead. + */ + + struct hash_entry *entry = + _mesa_hash_table_search(ctx->glsl_types, type); + if (entry) + return (SpvId)(uintptr_t)entry->data; + + SpvId ret; if (glsl_type_is_array(type)) { - SpvId ret; SpvId element_type = get_glsl_type(ctx, glsl_get_array_element(type)); if (glsl_type_is_unsized_array(type)) ret = spirv_builder_type_runtime_array(&ctx->builder, element_type); @@ -373,21 +384,19 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type) } if (stride) spirv_builder_emit_array_stride(&ctx->builder, ret, stride); - return ret; - } - if (glsl_type_is_struct_or_ifc(type)) { + } else if (glsl_type_is_struct_or_ifc(type)) { SpvId types[glsl_get_length(type)]; for (unsigned i = 0; i < glsl_get_length(type); i++) types[i] = get_glsl_type(ctx, glsl_get_struct_field(type, i)); - SpvId ret = spirv_builder_type_struct(&ctx->builder, - types, - glsl_get_length(type)); + ret = spirv_builder_type_struct(&ctx->builder, types, + glsl_get_length(type)); for (unsigned i = 0; i < glsl_get_length(type); i++) spirv_builder_emit_member_offset(&ctx->builder, ret, i, glsl_get_struct_field_offset(type, i)); - return ret; - } + } else + unreachable("Unhandled GLSL type"); - unreachable("we shouldn't get here, I think..."); + _mesa_hash_table_insert(ctx->glsl_types, type, (void *)(uintptr_t)ret); + return ret; } static void @@ -3563,6 +3572,10 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info, bool spir ctx.builder.mem_ctx = ctx.mem_ctx; ctx.spirv_15 = spirv_15; + ctx.glsl_types = _mesa_pointer_hash_table_create(ctx.mem_ctx); + if (!ctx.glsl_types) + goto fail; + switch (s->info.stage) { case MESA_SHADER_VERTEX: case MESA_SHADER_FRAGMENT: