vulkan/shader: Call vk_nir_lower_descriptor_heaps()

Embedded samplers (if present) are passed to the driver as part of the
vk_shader_compile_info

Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40649>
This commit is contained in:
Faith Ekstrand 2025-06-16 17:02:50 -04:00 committed by Marge Bot
parent f41dd1d157
commit e8558de16f
2 changed files with 40 additions and 5 deletions

View file

@ -29,6 +29,7 @@
#include "vk_descriptor_set_layout.h"
#include "vk_device.h"
#include "vk_nir.h"
#include "vk_nir_lower_descriptor_heaps.h"
#include "vk_physical_device.h"
#include "vk_physical_device_features.h"
#include "vk_pipeline.h"
@ -245,7 +246,8 @@ cmp_stage_idx(const void *_a, const void *_b)
static nir_shader *
vk_shader_to_nir(struct vk_device *device,
const VkShaderCreateInfoEXT *info,
const struct vk_pipeline_robustness_state *rs)
const struct vk_pipeline_robustness_state *rs,
struct vk_sampler_state_array *embedded_samplers_out)
{
const struct vk_device_shader_ops *ops = device->shader_ops;
const struct vk_properties *properties = &device->physical->properties;
@ -275,6 +277,19 @@ vk_shader_to_nir(struct vk_device *device,
if (ops->preprocess_nir != NULL)
ops->preprocess_nir(device->physical, nir, rs);
const VkShaderDescriptorSetAndBindingMappingInfoEXT *desc_map =
vk_find_struct_const(info->pNext,
SHADER_DESCRIPTOR_SET_AND_BINDING_MAPPING_INFO_EXT);
bool heaps_progress = false;
NIR_PASS(heaps_progress, nir, vk_nir_lower_descriptor_heaps,
desc_map, embedded_samplers_out);
if (heaps_progress) {
NIR_PASS(_, nir, nir_remove_dead_variables,
nir_var_uniform | nir_var_image, NULL);
NIR_PASS(_, nir, nir_opt_dce);
}
return nir;
}
@ -287,6 +302,7 @@ vk_shader_compile_info_init(struct vk_shader_compile_info *info,
struct set_layouts *set_layouts,
const VkShaderCreateInfoEXT *vk_info,
const struct vk_pipeline_robustness_state *rs,
const struct vk_sampler_state_array *es,
nir_shader *nir)
{
for (uint32_t sl = 0; sl < vk_info->setLayoutCount; sl++) {
@ -302,6 +318,8 @@ vk_shader_compile_info_init(struct vk_shader_compile_info *info,
.robustness = rs,
.set_layout_count = vk_info->setLayoutCount,
.set_layouts = set_layouts->set_layouts,
.embedded_sampler_count = es->sampler_count,
.embedded_samplers = es->samplers,
.push_constant_range_count = vk_info->pushConstantRangeCount,
.push_constant_ranges = vk_info->pPushConstantRanges,
};
@ -601,8 +619,10 @@ vk_common_CreateShadersEXT(VkDevice _device,
.idx = i,
};
} else {
struct vk_sampler_state_array embedded_samplers = {};
nir_shader *nir = vk_shader_to_nir(device, vk_info,
&vk_robustness_disabled);
&vk_robustness_disabled,
&embedded_samplers);
if (nir == NULL) {
result = vk_errorf(device, VK_ERROR_UNKNOWN,
"Failed to compile shader to NIR");
@ -612,12 +632,16 @@ vk_common_CreateShadersEXT(VkDevice _device,
struct vk_shader_compile_info info;
struct set_layouts set_layouts;
vk_shader_compile_info_init(&info, &set_layouts,
vk_info, &vk_robustness_disabled, nir);
vk_info, &vk_robustness_disabled,
&embedded_samplers, nir);
struct vk_shader *shader;
result = vk_compile_shaders(device, 1, &info,
NULL /* state */, NULL /* features */,
pAllocator, &shader);
vk_sampler_state_array_finish(&embedded_samplers);
if (result != VK_SUCCESS)
break;
@ -637,6 +661,7 @@ vk_common_CreateShadersEXT(VkDevice _device,
if (linked_count > 0) {
struct set_layouts set_layouts[VK_MAX_LINKED_SHADER_STAGES];
struct vk_shader_compile_info infos[VK_MAX_LINKED_SHADER_STAGES];
struct vk_sampler_state_array embedded_samplers[VK_MAX_LINKED_SHADER_STAGES];
VkResult result = VK_SUCCESS;
/* Sort so we guarantee the driver always gets them in-order */
@ -644,12 +669,14 @@ vk_common_CreateShadersEXT(VkDevice _device,
/* Memset for easy error handling */
memset(infos, 0, sizeof(infos));
memset(embedded_samplers, 0, sizeof(embedded_samplers));
for (uint32_t l = 0; l < linked_count; l++) {
const VkShaderCreateInfoEXT *vk_info = &pCreateInfos[linked[l].idx];
nir_shader *nir = vk_shader_to_nir(device, vk_info,
&vk_robustness_disabled);
&vk_robustness_disabled,
&embedded_samplers[l]);
if (nir == NULL) {
result = vk_errorf(device, VK_ERROR_UNKNOWN,
"Failed to compile shader to NIR");
@ -657,7 +684,8 @@ vk_common_CreateShadersEXT(VkDevice _device,
}
vk_shader_compile_info_init(&infos[l], &set_layouts[l],
vk_info, &vk_robustness_disabled, nir);
vk_info, &vk_robustness_disabled,
&embedded_samplers[l], nir);
}
if (result == VK_SUCCESS) {
@ -676,6 +704,9 @@ vk_common_CreateShadersEXT(VkDevice _device,
}
}
for (uint32_t l = 0; l < linked_count; l++)
vk_sampler_state_array_finish(&embedded_samplers[l]);
if (first_fail_or_success == VK_SUCCESS)
first_fail_or_success = result;
}

View file

@ -46,6 +46,7 @@ struct vk_features;
struct vk_physical_device;
struct vk_pipeline;
struct vk_pipeline_robustness_state;
struct vk_sampler_state;
bool vk_validate_shader_binaries(void);
@ -100,6 +101,9 @@ struct vk_shader_compile_info {
uint32_t set_layout_count;
struct vk_descriptor_set_layout * const *set_layouts;
uint32_t embedded_sampler_count;
const struct vk_sampler_state* embedded_samplers;
uint32_t push_constant_range_count;
const VkPushConstantRange *push_constant_ranges;
};