diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 76d13481f96..744532b0d41 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -1464,17 +1464,39 @@ radv_GetRayTracingShaderGroupStackSizeKHR(VkDevice device, VkPipeline _pipeline, VK_FROM_HANDLE(radv_pipeline, pipeline, _pipeline); struct radv_ray_tracing_pipeline *rt_pipeline = radv_pipeline_to_ray_tracing(pipeline); struct radv_ray_tracing_group *rt_group = &rt_pipeline->groups[group]; + + struct radv_ray_tracing_stage *shader_stage; + switch (groupShader) { case VK_SHADER_GROUP_SHADER_GENERAL_KHR: case VK_SHADER_GROUP_SHADER_CLOSEST_HIT_KHR: - return rt_pipeline->stages[rt_group->recursive_shader].stack_size; + shader_stage = &rt_pipeline->stages[rt_group->recursive_shader]; + break; case VK_SHADER_GROUP_SHADER_ANY_HIT_KHR: - return rt_pipeline->stages[rt_group->any_hit_shader].stack_size; + /* If the any-hit shader is inlined into an intersection shader, there is no stack specific to the any-hit shader + * and all stack will be allocated for the intersection shader instead. + */ + if (rt_group->intersection_shader != VK_SHADER_UNUSED_KHR) + return 0; + shader_stage = &rt_pipeline->stages[rt_group->any_hit_shader]; + break; case VK_SHADER_GROUP_SHADER_INTERSECTION_KHR: - return rt_pipeline->stages[rt_group->intersection_shader].stack_size; + shader_stage = &rt_pipeline->stages[rt_group->intersection_shader]; + break; default: return 0; } + + uint32_t stack_size = shader_stage->stack_size; + /* Applications need to allocate stack for the traversal shader, too. The API doesn't intend for a constant + * traversal stack size, so add the stack size to every shader potentially called by the traversal shader. + * Applications are expected to max() shader stages together, so this shouldn't result in any unnecessary stack + * usage. + */ + if (shader_stage->stage == MESA_SHADER_CLOSEST_HIT || shader_stage->stage == MESA_SHADER_ANY_HIT || + shader_stage->stage == MESA_SHADER_INTERSECTION || shader_stage->stage == MESA_SHADER_MISS) + stack_size += rt_pipeline->traversal_stack_size; + return stack_size; } VKAPI_ATTR VkResult VKAPI_CALL