From aabdd0f5ca872779e35f24cf738ea0899e2285b3 Mon Sep 17 00:00:00 2001 From: Mary Guillemard Date: Wed, 5 Nov 2025 14:46:39 -0500 Subject: [PATCH] panvk: Move late lowering to panvk_compile_nir() This is needed for intershader optimization and GS lowering. We now pass all NIR variants with pan_compile_inputs to panvk_compile_shader and handle sysvals/push consts lowering in there. Signed-off-by: Mary Guillemard Reviewed-by: Faith Ekstrand Reviewed-by: Christoph Pillmayer Acked-by: Eric R. Smith Part-of: --- src/panfrost/vulkan/panvk_vX_shader.c | 45 ++++++++++++++------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/panfrost/vulkan/panvk_vX_shader.c b/src/panfrost/vulkan/panvk_vX_shader.c index 89ee3ae3ebc..47027131ba6 100644 --- a/src/panfrost/vulkan/panvk_vX_shader.c +++ b/src/panfrost/vulkan/panvk_vX_shader.c @@ -769,7 +769,6 @@ panvk_lower_nir(struct panvk_device *dev, nir_shader *nir, uint32_t set_layout_count, struct vk_descriptor_set_layout *const *set_layouts, const struct vk_pipeline_robustness_state *rs, - const uint32_t *noperspective_varyings, const struct vk_graphics_pipeline_state *state, const struct pan_compile_inputs *compile_input, struct panvk_shader_variant *shader) @@ -918,24 +917,36 @@ panvk_lower_nir(struct panvk_device *dev, nir_shader *nir, nir_var_shader_in | nir_var_shader_out, UINT32_MAX); NIR_PASS(_, nir, nir_lower_io, nir_var_shader_in | nir_var_shader_out, glsl_type_size, nir_lower_io_use_interpolated_input_intrinsics); +} + +static VkResult +panvk_compile_nir(struct panvk_device *dev, nir_shader *nir, + VkShaderCreateFlagsEXT shader_flags, + struct pan_compile_inputs *compile_input, + const struct vk_graphics_pipeline_state *state, + const uint32_t *noperspective_varyings, + struct panvk_shader_variant *shader) +{ + const bool dump_asm = + shader_flags & VK_SHADER_CREATE_CAPTURE_INTERNAL_REPRESENTATIONS_BIT_MESA; pan_postprocess_nir(nir, compile_input->gpu_id); pan_nir_lower_texture_late(nir, compile_input->gpu_id); - if (stage == MESA_SHADER_VERTEX) + if (nir->info.stage == MESA_SHADER_VERTEX) NIR_PASS(_, nir, nir_shader_intrinsics_pass, panvk_lower_load_vs_input, nir_metadata_control_flow, NULL); - else if (stage == MESA_SHADER_FRAGMENT) + else if (nir->info.stage == MESA_SHADER_FRAGMENT) NIR_PASS(_, nir, nir_shader_intrinsics_pass, panvk_lower_load_fs_input, nir_metadata_control_flow, NULL); /* since valhall, panvk_per_arch(nir_lower_descriptors) separates the * driver set and the user sets, and does not need pan_nir_lower_image_index */ - if (PAN_ARCH < 9 && stage == MESA_SHADER_VERTEX) + if (PAN_ARCH < 9 && nir->info.stage == MESA_SHADER_VERTEX) NIR_PASS(_, nir, pan_nir_lower_image_index, MAX_VS_ATTRIBS); - if (noperspective_varyings && stage == MESA_SHADER_VERTEX) { + if (noperspective_varyings && nir->info.stage == MESA_SHADER_VERTEX) { NIR_PASS(_, nir, nir_inline_sysval, nir_intrinsic_load_noperspective_varyings_pan, *noperspective_varyings); @@ -950,17 +961,6 @@ panvk_lower_nir(struct panvk_device *dev, nir_shader *nir, nir_metadata_control_flow, &lower_sysvals_ctx); lower_load_push_consts(nir, shader); -} - -static VkResult -panvk_compile_nir(struct panvk_device *dev, nir_shader *nir, - VkShaderCreateFlagsEXT shader_flags, - struct pan_compile_inputs *compile_input, - struct panvk_shader_variant *shader) -{ - const bool dump_asm = - shader_flags & VK_SHADER_CREATE_CAPTURE_INTERNAL_REPRESENTATIONS_BIT_MESA; - /* Allow the remaining FAU space to be filled with constants. */ compile_input->fau_consts.max_amount = 2 * (FAU_WORD_COUNT - shader->fau.total_count); @@ -1359,13 +1359,14 @@ panvk_compile_shader(struct panvk_device *dev, panvk_lower_nir(dev, nir_variants[v], info->set_layout_count, info->set_layouts, info->robustness, - noperspective_varyings, state, &input_variants[v], - variant); + state, &input_variants[v], variant); variant->own_bin = true; result = panvk_compile_nir(dev, nir_variants[v], info->flags, - &input_variants[v], variant); + &input_variants[v], state, + noperspective_varyings, + variant); if (result != VK_SUCCESS) { panvk_shader_destroy(&dev->vk, &shader->vk, pAllocator); return result; @@ -1384,8 +1385,7 @@ panvk_compile_shader(struct panvk_device *dev, variant->own_bin = true; panvk_lower_nir(dev, nir, info->set_layout_count, info->set_layouts, - info->robustness, noperspective_varyings, state, &inputs, - variant); + info->robustness, state, &inputs, variant); #if PAN_ARCH >= 9 if (info->stage == MESA_SHADER_FRAGMENT) @@ -1393,7 +1393,8 @@ panvk_compile_shader(struct panvk_device *dev, inputs.valhall.use_ld_var_buf = panvk_use_ld_var_buf(variant); #endif - result = panvk_compile_nir(dev, nir, info->flags, &inputs, variant); + result = panvk_compile_nir(dev, nir, info->flags, &inputs, state, + noperspective_varyings, variant); if (result != VK_SUCCESS) { panvk_shader_destroy(&dev->vk, &shader->vk, pAllocator); return result;