diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index ddb100c3943..43a02680fa6 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -5664,6 +5664,17 @@ lower_vri_to_var(nir_shader *nir) nir_metadata_control_flow, NULL); } +static void +shared_type_info(const struct glsl_type *type, unsigned *size, unsigned *align) +{ + assert(glsl_type_is_vector_or_scalar(type)); + + uint32_t comp_size = glsl_type_is_boolean(type) + ? 4 : glsl_get_bit_size(type) / 8; + unsigned length = glsl_get_vector_elements(type); + *size = comp_size * length, + *align = comp_size * (length == 3 ? 4 : length); +} /* this is the bare minimum required to make vtn shaders work with ntv */ void ntv_shader_prepare(nir_shader *nir) @@ -5675,6 +5686,13 @@ ntv_shader_prepare(nir_shader *nir) NIR_PASS(_, nir, nir_split_per_member_structs); NIR_PASS(_, nir, nir_lower_returns); NIR_PASS(_, nir, nir_inline_functions); + if (nir->info.stage == MESA_SHADER_COMPUTE || + nir->info.stage == MESA_SHADER_TASK || + nir->info.stage == MESA_SHADER_MESH) { + nir_variable_mode modes = nir_var_mem_shared | nir_var_mem_task_payload; + NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, modes, shared_type_info); + NIR_PASS(_, nir, nir_lower_explicit_io, modes, nir_address_format_32bit_offset); + } nir_cleanup_functions(nir); optimize_nir(nir); NIR_PASS(_, nir, nir_lower_variable_initializers, nir_var_shader_temp);