panvk: Move late lowering to panvk_compile_nir()

This is needed for intershader optimization and GS lowering.

We now pass all NIR variants with pan_compile_inputs to
panvk_compile_shader and handle sysvals/push consts lowering in there.

Signed-off-by: Mary Guillemard <mary.guillemard@collabora.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Reviewed-by: Christoph Pillmayer <christoph.pillmayer@arm.com>
Acked-by: Eric R. Smith <eric.smith@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38821>
This commit is contained in:
Mary Guillemard 2025-11-05 14:46:39 -05:00 committed by Marge Bot
parent f4470dd7d7
commit aabdd0f5ca

View file

@ -769,7 +769,6 @@ panvk_lower_nir(struct panvk_device *dev, nir_shader *nir,
uint32_t set_layout_count,
struct vk_descriptor_set_layout *const *set_layouts,
const struct vk_pipeline_robustness_state *rs,
const uint32_t *noperspective_varyings,
const struct vk_graphics_pipeline_state *state,
const struct pan_compile_inputs *compile_input,
struct panvk_shader_variant *shader)
@ -918,24 +917,36 @@ panvk_lower_nir(struct panvk_device *dev, nir_shader *nir,
nir_var_shader_in | nir_var_shader_out, UINT32_MAX);
NIR_PASS(_, nir, nir_lower_io, nir_var_shader_in | nir_var_shader_out,
glsl_type_size, nir_lower_io_use_interpolated_input_intrinsics);
}
static VkResult
panvk_compile_nir(struct panvk_device *dev, nir_shader *nir,
VkShaderCreateFlagsEXT shader_flags,
struct pan_compile_inputs *compile_input,
const struct vk_graphics_pipeline_state *state,
const uint32_t *noperspective_varyings,
struct panvk_shader_variant *shader)
{
const bool dump_asm =
shader_flags & VK_SHADER_CREATE_CAPTURE_INTERNAL_REPRESENTATIONS_BIT_MESA;
pan_postprocess_nir(nir, compile_input->gpu_id);
pan_nir_lower_texture_late(nir, compile_input->gpu_id);
if (stage == MESA_SHADER_VERTEX)
if (nir->info.stage == MESA_SHADER_VERTEX)
NIR_PASS(_, nir, nir_shader_intrinsics_pass, panvk_lower_load_vs_input,
nir_metadata_control_flow, NULL);
else if (stage == MESA_SHADER_FRAGMENT)
else if (nir->info.stage == MESA_SHADER_FRAGMENT)
NIR_PASS(_, nir, nir_shader_intrinsics_pass, panvk_lower_load_fs_input,
nir_metadata_control_flow, NULL);
/* since valhall, panvk_per_arch(nir_lower_descriptors) separates the
* driver set and the user sets, and does not need pan_nir_lower_image_index
*/
if (PAN_ARCH < 9 && stage == MESA_SHADER_VERTEX)
if (PAN_ARCH < 9 && nir->info.stage == MESA_SHADER_VERTEX)
NIR_PASS(_, nir, pan_nir_lower_image_index, MAX_VS_ATTRIBS);
if (noperspective_varyings && stage == MESA_SHADER_VERTEX) {
if (noperspective_varyings && nir->info.stage == MESA_SHADER_VERTEX) {
NIR_PASS(_, nir, nir_inline_sysval,
nir_intrinsic_load_noperspective_varyings_pan,
*noperspective_varyings);
@ -950,17 +961,6 @@ panvk_lower_nir(struct panvk_device *dev, nir_shader *nir,
nir_metadata_control_flow, &lower_sysvals_ctx);
lower_load_push_consts(nir, shader);
}
static VkResult
panvk_compile_nir(struct panvk_device *dev, nir_shader *nir,
VkShaderCreateFlagsEXT shader_flags,
struct pan_compile_inputs *compile_input,
struct panvk_shader_variant *shader)
{
const bool dump_asm =
shader_flags & VK_SHADER_CREATE_CAPTURE_INTERNAL_REPRESENTATIONS_BIT_MESA;
/* Allow the remaining FAU space to be filled with constants. */
compile_input->fau_consts.max_amount =
2 * (FAU_WORD_COUNT - shader->fau.total_count);
@ -1359,13 +1359,14 @@ panvk_compile_shader(struct panvk_device *dev,
panvk_lower_nir(dev, nir_variants[v], info->set_layout_count,
info->set_layouts, info->robustness,
noperspective_varyings, state, &input_variants[v],
variant);
state, &input_variants[v], variant);
variant->own_bin = true;
result = panvk_compile_nir(dev, nir_variants[v], info->flags,
&input_variants[v], variant);
&input_variants[v], state,
noperspective_varyings,
variant);
if (result != VK_SUCCESS) {
panvk_shader_destroy(&dev->vk, &shader->vk, pAllocator);
return result;
@ -1384,8 +1385,7 @@ panvk_compile_shader(struct panvk_device *dev,
variant->own_bin = true;
panvk_lower_nir(dev, nir, info->set_layout_count, info->set_layouts,
info->robustness, noperspective_varyings, state, &inputs,
variant);
info->robustness, state, &inputs, variant);
#if PAN_ARCH >= 9
if (info->stage == MESA_SHADER_FRAGMENT)
@ -1393,7 +1393,8 @@ panvk_compile_shader(struct panvk_device *dev,
inputs.valhall.use_ld_var_buf = panvk_use_ld_var_buf(variant);
#endif
result = panvk_compile_nir(dev, nir, info->flags, &inputs, variant);
result = panvk_compile_nir(dev, nir, info->flags, &inputs, state,
noperspective_varyings, variant);
if (result != VK_SUCCESS) {
panvk_shader_destroy(&dev->vk, &shader->vk, pAllocator);
return result;