From e8558de16fea201e47b614f57a8cea8e5703abcb Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Mon, 16 Jun 2025 17:02:50 -0400 Subject: [PATCH] vulkan/shader: Call vk_nir_lower_descriptor_heaps() Embedded samplers (if present) are passed to the driver as part of the vk_shader_compile_info Reviewed-by: Samuel Pitoiset Part-of: --- src/vulkan/runtime/vk_shader.c | 41 +++++++++++++++++++++++++++++----- src/vulkan/runtime/vk_shader.h | 4 ++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/vulkan/runtime/vk_shader.c b/src/vulkan/runtime/vk_shader.c index 0823db25141..cf2968f2b19 100644 --- a/src/vulkan/runtime/vk_shader.c +++ b/src/vulkan/runtime/vk_shader.c @@ -29,6 +29,7 @@ #include "vk_descriptor_set_layout.h" #include "vk_device.h" #include "vk_nir.h" +#include "vk_nir_lower_descriptor_heaps.h" #include "vk_physical_device.h" #include "vk_physical_device_features.h" #include "vk_pipeline.h" @@ -245,7 +246,8 @@ cmp_stage_idx(const void *_a, const void *_b) static nir_shader * vk_shader_to_nir(struct vk_device *device, const VkShaderCreateInfoEXT *info, - const struct vk_pipeline_robustness_state *rs) + const struct vk_pipeline_robustness_state *rs, + struct vk_sampler_state_array *embedded_samplers_out) { const struct vk_device_shader_ops *ops = device->shader_ops; const struct vk_properties *properties = &device->physical->properties; @@ -275,6 +277,19 @@ vk_shader_to_nir(struct vk_device *device, if (ops->preprocess_nir != NULL) ops->preprocess_nir(device->physical, nir, rs); + const VkShaderDescriptorSetAndBindingMappingInfoEXT *desc_map = + vk_find_struct_const(info->pNext, + SHADER_DESCRIPTOR_SET_AND_BINDING_MAPPING_INFO_EXT); + + bool heaps_progress = false; + NIR_PASS(heaps_progress, nir, vk_nir_lower_descriptor_heaps, + desc_map, embedded_samplers_out); + if (heaps_progress) { + NIR_PASS(_, nir, nir_remove_dead_variables, + nir_var_uniform | nir_var_image, NULL); + NIR_PASS(_, nir, nir_opt_dce); + } + return nir; } @@ -287,6 +302,7 @@ vk_shader_compile_info_init(struct vk_shader_compile_info *info, struct set_layouts *set_layouts, const VkShaderCreateInfoEXT *vk_info, const struct vk_pipeline_robustness_state *rs, + const struct vk_sampler_state_array *es, nir_shader *nir) { for (uint32_t sl = 0; sl < vk_info->setLayoutCount; sl++) { @@ -302,6 +318,8 @@ vk_shader_compile_info_init(struct vk_shader_compile_info *info, .robustness = rs, .set_layout_count = vk_info->setLayoutCount, .set_layouts = set_layouts->set_layouts, + .embedded_sampler_count = es->sampler_count, + .embedded_samplers = es->samplers, .push_constant_range_count = vk_info->pushConstantRangeCount, .push_constant_ranges = vk_info->pPushConstantRanges, }; @@ -601,8 +619,10 @@ vk_common_CreateShadersEXT(VkDevice _device, .idx = i, }; } else { + struct vk_sampler_state_array embedded_samplers = {}; nir_shader *nir = vk_shader_to_nir(device, vk_info, - &vk_robustness_disabled); + &vk_robustness_disabled, + &embedded_samplers); if (nir == NULL) { result = vk_errorf(device, VK_ERROR_UNKNOWN, "Failed to compile shader to NIR"); @@ -612,12 +632,16 @@ vk_common_CreateShadersEXT(VkDevice _device, struct vk_shader_compile_info info; struct set_layouts set_layouts; vk_shader_compile_info_init(&info, &set_layouts, - vk_info, &vk_robustness_disabled, nir); + vk_info, &vk_robustness_disabled, + &embedded_samplers, nir); struct vk_shader *shader; result = vk_compile_shaders(device, 1, &info, NULL /* state */, NULL /* features */, pAllocator, &shader); + + vk_sampler_state_array_finish(&embedded_samplers); + if (result != VK_SUCCESS) break; @@ -637,6 +661,7 @@ vk_common_CreateShadersEXT(VkDevice _device, if (linked_count > 0) { struct set_layouts set_layouts[VK_MAX_LINKED_SHADER_STAGES]; struct vk_shader_compile_info infos[VK_MAX_LINKED_SHADER_STAGES]; + struct vk_sampler_state_array embedded_samplers[VK_MAX_LINKED_SHADER_STAGES]; VkResult result = VK_SUCCESS; /* Sort so we guarantee the driver always gets them in-order */ @@ -644,12 +669,14 @@ vk_common_CreateShadersEXT(VkDevice _device, /* Memset for easy error handling */ memset(infos, 0, sizeof(infos)); + memset(embedded_samplers, 0, sizeof(embedded_samplers)); for (uint32_t l = 0; l < linked_count; l++) { const VkShaderCreateInfoEXT *vk_info = &pCreateInfos[linked[l].idx]; nir_shader *nir = vk_shader_to_nir(device, vk_info, - &vk_robustness_disabled); + &vk_robustness_disabled, + &embedded_samplers[l]); if (nir == NULL) { result = vk_errorf(device, VK_ERROR_UNKNOWN, "Failed to compile shader to NIR"); @@ -657,7 +684,8 @@ vk_common_CreateShadersEXT(VkDevice _device, } vk_shader_compile_info_init(&infos[l], &set_layouts[l], - vk_info, &vk_robustness_disabled, nir); + vk_info, &vk_robustness_disabled, + &embedded_samplers[l], nir); } if (result == VK_SUCCESS) { @@ -676,6 +704,9 @@ vk_common_CreateShadersEXT(VkDevice _device, } } + for (uint32_t l = 0; l < linked_count; l++) + vk_sampler_state_array_finish(&embedded_samplers[l]); + if (first_fail_or_success == VK_SUCCESS) first_fail_or_success = result; } diff --git a/src/vulkan/runtime/vk_shader.h b/src/vulkan/runtime/vk_shader.h index 7ff18dc1aca..4e0d4dc84f3 100644 --- a/src/vulkan/runtime/vk_shader.h +++ b/src/vulkan/runtime/vk_shader.h @@ -46,6 +46,7 @@ struct vk_features; struct vk_physical_device; struct vk_pipeline; struct vk_pipeline_robustness_state; +struct vk_sampler_state; bool vk_validate_shader_binaries(void); @@ -100,6 +101,9 @@ struct vk_shader_compile_info { uint32_t set_layout_count; struct vk_descriptor_set_layout * const *set_layouts; + uint32_t embedded_sampler_count; + const struct vk_sampler_state* embedded_samplers; + uint32_t push_constant_range_count; const VkPushConstantRange *push_constant_ranges; };