vulkan: set nir subgroup size shader info

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37258>
This commit is contained in:
Georg Lehmann 2025-09-09 17:46:07 +02:00 committed by Marge Bot
parent ce91b0be08
commit d807f5a351
3 changed files with 64 additions and 0 deletions

View file

@ -127,6 +127,46 @@ vk_get_subgroup_size(uint32_t spirv_version,
}
}
void
vk_set_subgroup_size(struct vk_device *device,
nir_shader *shader,
uint32_t spirv_version,
const void *info_pNext,
bool allow_varying,
bool require_full)
{
struct vk_properties *properties = &device->physical->properties;
uint32_t req_subgroup_size = get_required_subgroup_size(info_pNext);
if (req_subgroup_size) {
assert(util_is_power_of_two_nonzero(req_subgroup_size));
assert(req_subgroup_size >= 1 && req_subgroup_size <= 128);
shader->info.api_subgroup_size = req_subgroup_size;
shader->info.max_subgroup_size = req_subgroup_size;
shader->info.min_subgroup_size = req_subgroup_size;
} else if (allow_varying || spirv_version >= 0x10600) {
/* Starting with SPIR-V 1.6, varying subgroup size is the default */
} else {
shader->info.api_subgroup_size = properties->subgroupSize;
if (require_full) {
assert(shader->info.stage == MESA_SHADER_COMPUTE ||
shader->info.stage == MESA_SHADER_MESH ||
shader->info.stage == MESA_SHADER_TASK);
shader->info.max_subgroup_size = properties->subgroupSize;
shader->info.min_subgroup_size = properties->subgroupSize;
}
}
if (properties->maxSubgroupSize) {
assert(properties->minSubgroupSize);
shader->info.max_subgroup_size =
MIN2(shader->info.max_subgroup_size, properties->maxSubgroupSize);
shader->info.min_subgroup_size =
MAX2(shader->info.min_subgroup_size, properties->minSubgroupSize);
}
assert(shader->info.max_subgroup_size >= shader->info.min_subgroup_size);
}
VkResult
vk_pipeline_shader_stage_to_nir(struct vk_device *device,
VkPipelineCreateFlags2KHR pipeline_flags,
@ -186,6 +226,13 @@ vk_pipeline_shader_stage_to_nir(struct vk_device *device,
if (nir == NULL)
return vk_errorf(device, VK_ERROR_UNKNOWN, "spirv_to_nir failed");
vk_set_subgroup_size(
device, nir,
vk_spirv_version(spirv_data, spirv_size),
info->pNext,
info->flags & VK_PIPELINE_SHADER_STAGE_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT,
info->flags & VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT);
if (pipeline_flags & VK_PIPELINE_CREATE_2_VIEW_INDEX_FROM_DEVICE_INDEX_BIT_KHR)
NIR_PASS(_, nir, nir_lower_view_index_to_device_index);

View file

@ -61,6 +61,16 @@ vk_get_subgroup_size(uint32_t spirv_version,
bool allow_varying,
bool require_full);
typedef struct nir_shader nir_shader;
void
vk_set_subgroup_size(struct vk_device *device,
nir_shader *shader,
uint32_t spirv_version,
const void *info_pNext,
bool allow_varying,
bool require_full);
struct vk_pipeline_robustness_state {
VkPipelineRobustnessBufferBehaviorEXT storage_buffers;
VkPipelineRobustnessBufferBehaviorEXT uniform_buffers;

View file

@ -180,6 +180,13 @@ vk_shader_to_nir(struct vk_device *device,
if (nir == NULL)
return NULL;
vk_set_subgroup_size(
device, nir,
vk_spirv_version(info->pCode, info->codeSize),
info->pNext,
info->flags & VK_SHADER_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT_EXT,
info->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT);
if (ops->preprocess_nir != NULL)
ops->preprocess_nir(device->physical, nir, rs);