diff --git a/src/vulkan/runtime/vk_pipeline.c b/src/vulkan/runtime/vk_pipeline.c index dffa874dd3f..f44638c7c74 100644 --- a/src/vulkan/runtime/vk_pipeline.c +++ b/src/vulkan/runtime/vk_pipeline.c @@ -127,6 +127,46 @@ vk_get_subgroup_size(uint32_t spirv_version, } } +void +vk_set_subgroup_size(struct vk_device *device, + nir_shader *shader, + uint32_t spirv_version, + const void *info_pNext, + bool allow_varying, + bool require_full) +{ + struct vk_properties *properties = &device->physical->properties; + uint32_t req_subgroup_size = get_required_subgroup_size(info_pNext); + if (req_subgroup_size) { + assert(util_is_power_of_two_nonzero(req_subgroup_size)); + assert(req_subgroup_size >= 1 && req_subgroup_size <= 128); + shader->info.api_subgroup_size = req_subgroup_size; + shader->info.max_subgroup_size = req_subgroup_size; + shader->info.min_subgroup_size = req_subgroup_size; + } else if (allow_varying || spirv_version >= 0x10600) { + /* Starting with SPIR-V 1.6, varying subgroup size is the default */ + } else { + shader->info.api_subgroup_size = properties->subgroupSize; + if (require_full) { + assert(shader->info.stage == MESA_SHADER_COMPUTE || + shader->info.stage == MESA_SHADER_MESH || + shader->info.stage == MESA_SHADER_TASK); + shader->info.max_subgroup_size = properties->subgroupSize; + shader->info.min_subgroup_size = properties->subgroupSize; + } + } + + if (properties->maxSubgroupSize) { + assert(properties->minSubgroupSize); + shader->info.max_subgroup_size = + MIN2(shader->info.max_subgroup_size, properties->maxSubgroupSize); + shader->info.min_subgroup_size = + MAX2(shader->info.min_subgroup_size, properties->minSubgroupSize); + } + + assert(shader->info.max_subgroup_size >= shader->info.min_subgroup_size); +} + VkResult vk_pipeline_shader_stage_to_nir(struct vk_device *device, VkPipelineCreateFlags2KHR pipeline_flags, @@ -186,6 +226,13 @@ vk_pipeline_shader_stage_to_nir(struct vk_device *device, if (nir == NULL) return vk_errorf(device, VK_ERROR_UNKNOWN, "spirv_to_nir failed"); + vk_set_subgroup_size( + device, nir, + vk_spirv_version(spirv_data, spirv_size), + info->pNext, + info->flags & VK_PIPELINE_SHADER_STAGE_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT, + info->flags & VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT); + if (pipeline_flags & VK_PIPELINE_CREATE_2_VIEW_INDEX_FROM_DEVICE_INDEX_BIT_KHR) NIR_PASS(_, nir, nir_lower_view_index_to_device_index); diff --git a/src/vulkan/runtime/vk_pipeline.h b/src/vulkan/runtime/vk_pipeline.h index ca520d9c74c..f9bc7a8874b 100644 --- a/src/vulkan/runtime/vk_pipeline.h +++ b/src/vulkan/runtime/vk_pipeline.h @@ -61,6 +61,16 @@ vk_get_subgroup_size(uint32_t spirv_version, bool allow_varying, bool require_full); +typedef struct nir_shader nir_shader; + +void +vk_set_subgroup_size(struct vk_device *device, + nir_shader *shader, + uint32_t spirv_version, + const void *info_pNext, + bool allow_varying, + bool require_full); + struct vk_pipeline_robustness_state { VkPipelineRobustnessBufferBehaviorEXT storage_buffers; VkPipelineRobustnessBufferBehaviorEXT uniform_buffers; diff --git a/src/vulkan/runtime/vk_shader.c b/src/vulkan/runtime/vk_shader.c index 0078693791a..20185b1bcaf 100644 --- a/src/vulkan/runtime/vk_shader.c +++ b/src/vulkan/runtime/vk_shader.c @@ -180,6 +180,13 @@ vk_shader_to_nir(struct vk_device *device, if (nir == NULL) return NULL; + vk_set_subgroup_size( + device, nir, + vk_spirv_version(info->pCode, info->codeSize), + info->pNext, + info->flags & VK_SHADER_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT_EXT, + info->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT); + if (ops->preprocess_nir != NULL) ops->preprocess_nir(device->physical, nir, rs);