From a98b44cd34faff3d258cd7270e747ca8eff28f58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Sch=C3=BCrmann?= Date: Fri, 14 Apr 2023 12:00:03 +0200 Subject: [PATCH] radv/rt: add shader stage indices to radv_ray_tracing_group Part-of: --- src/amd/vulkan/radv_pipeline_rt.c | 48 +++++++++++++++++++++++++++++-- src/amd/vulkan/radv_private.h | 4 +++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 223e0b2fdc8..2bc06c9198f 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -132,6 +132,50 @@ radv_create_group_handles(struct radv_device *device, return VK_SUCCESS; } +static VkResult +radv_rt_fill_group_info(struct radv_device *device, + const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, + struct radv_ray_tracing_group *groups) +{ + VkResult result = radv_create_group_handles(device, pCreateInfo, groups); + + uint32_t idx; + for (idx = 0; idx < pCreateInfo->groupCount; idx++) { + groups[idx].type = pCreateInfo->pGroups[idx].type; + if (groups[idx].type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) + groups[idx].recursive_shader = pCreateInfo->pGroups[idx].generalShader; + else + groups[idx].recursive_shader = pCreateInfo->pGroups[idx].closestHitShader; + groups[idx].any_hit_shader = pCreateInfo->pGroups[idx].anyHitShader; + groups[idx].intersection_shader = pCreateInfo->pGroups[idx].intersectionShader; + } + + /* copy and adjust library groups (incl. handles) */ + if (pCreateInfo->pLibraryInfo) { + unsigned stage_count = pCreateInfo->stageCount; + for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) { + RADV_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]); + struct radv_ray_tracing_lib_pipeline *library_pipeline = + radv_pipeline_to_ray_tracing_lib(pipeline); + + for (unsigned j = 0; j < library_pipeline->group_count; ++j) { + struct radv_ray_tracing_group *dst = &groups[idx + j]; + *dst = library_pipeline->groups[j]; + if (dst->recursive_shader != VK_SHADER_UNUSED_KHR) + dst->recursive_shader += stage_count; + if (dst->any_hit_shader != VK_SHADER_UNUSED_KHR) + dst->any_hit_shader += stage_count; + if (dst->intersection_shader != VK_SHADER_UNUSED_KHR) + dst->intersection_shader += stage_count; + } + idx += library_pipeline->group_count; + stage_count += library_pipeline->stage_count; + } + } + + return result; +} + static VkRayTracingPipelineCreateInfoKHR radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo) { @@ -349,7 +393,7 @@ radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache, pipeline->ctx = ralloc_context(NULL); - result = radv_create_group_handles(device, &local_create_info, pipeline->groups); + result = radv_rt_fill_group_info(device, pCreateInfo, pipeline->groups); if (result != VK_SUCCESS) goto pipeline_fail; @@ -555,7 +599,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, radv_pipeline_init(device, &rt_pipeline->base.base, RADV_PIPELINE_RAY_TRACING); rt_pipeline->group_count = local_create_info.groupCount; - result = radv_create_group_handles(device, &local_create_info, rt_pipeline->groups); + result = radv_rt_fill_group_info(device, pCreateInfo, rt_pipeline->groups); if (result != VK_SUCCESS) goto pipeline_fail; diff --git a/src/amd/vulkan/radv_private.h b/src/amd/vulkan/radv_private.h index 40ae52149c8..35c2296a4fd 100644 --- a/src/amd/vulkan/radv_private.h +++ b/src/amd/vulkan/radv_private.h @@ -2285,6 +2285,10 @@ struct radv_compute_pipeline { }; struct radv_ray_tracing_group { + VkRayTracingShaderGroupTypeKHR type; + uint32_t recursive_shader; /* generalShader or closestHitShader */ + uint32_t any_hit_shader; + uint32_t intersection_shader; struct radv_pipeline_group_handle handle; struct radv_pipeline_shader_stack_size stack_size; };