diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 1c72505f45c..42ae81124c2 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -25,6 +25,7 @@ #include "spirv_builder.h" #include "nir.h" +#include "nir_builder.h" #include "pipe/p_state.h" #include "util/u_math.h" #include "util/u_memory.h" @@ -5507,11 +5508,74 @@ optimize_nir(struct nir_shader *nir) NIR_PASS(progress, nir, nir_opt_shrink_vectors, true); } +/* deref casts produce results with 2 components; direct derefs have 1 component */ +static void +fixup_deref_components(nir_deref_instr *deref) +{ + nir_foreach_use(src, &deref->def) { + nir_instr *user_instr = nir_src_parent_instr(src); + if (user_instr->type != nir_instr_type_deref) + continue; + nir_deref_instr *user_deref = nir_instr_as_deref(user_instr); + user_deref->def.num_components = 1; + fixup_deref_components(user_deref); + } +} + +/* convert (vulkan_resource_index -> load_vulkan_descriptor -> deref_cast) into deref_var */ +static bool +lower_vri_to_var_instr(nir_builder *b, nir_instr *instr, void *data) +{ + if (instr->type != nir_instr_type_deref) + return false; + + nir_deref_instr *deref = nir_instr_as_deref(instr); + if (deref->deref_type != nir_deref_type_cast) + return false; + if (!(deref->modes & (nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_uniform | nir_var_image))) + return false; + if (!nir_src_is_intrinsic(deref->parent)) + return false; + nir_intrinsic_instr *lvd = nir_src_as_intrinsic(deref->parent); + if (lvd->intrinsic != nir_intrinsic_load_vulkan_descriptor) + return false; + nir_intrinsic_instr *vri = nir_src_as_intrinsic(lvd->src[0]); + if (vri->intrinsic != nir_intrinsic_vulkan_resource_index) + return false; + + int desc_set = nir_intrinsic_desc_set(vri); + int binding = nir_intrinsic_binding(vri); + nir_variable *var = NULL; + nir_foreach_variable_with_modes(i, b->shader, deref->modes) { + if (i->data.descriptor_set == desc_set && i->data.binding == binding) { + var = i; + break; + } + } + if (!var) + return false; + + b->cursor = nir_after_instr(instr); + nir_deref_instr *var_deref = nir_build_deref_var(b, var); + fixup_deref_components(deref); + nir_def_rewrite_uses_after_instr(&deref->def, &var_deref->def, instr); + + return true; +} + +static bool +lower_vri_to_var(nir_shader *nir) +{ + return nir_shader_instructions_pass(nir, lower_vri_to_var_instr, + nir_metadata_control_flow, NULL); +} + /* this is the bare minimum required to make vtn shaders work with ntv */ void ntv_shader_prepare(nir_shader *nir) { struct nir_lower_compute_system_values_options cs_options = {0}; + NIR_PASS(_, nir, lower_vri_to_var); NIR_PASS(_, nir, nir_lower_system_values); NIR_PASS(_, nir, nir_lower_compute_system_values, &cs_options); NIR_PASS(_, nir, nir_split_per_member_structs);