From 4c3a74bebe5b35612bdbf44b235c9a1423bd2c50 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Wed, 25 Mar 2026 10:38:14 +0000 Subject: [PATCH] radv: move radv_shader_create out of radv_rt_nir_to_asm Signed-off-by: Rhys Perry Reviewed-by: Samuel Pitoiset Part-of: --- src/amd/vulkan/radv_pipeline_rt.c | 118 ++++++++++++++++-------------- 1 file changed, 62 insertions(+), 56 deletions(-) diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index fb98e33237d..c90787513e2 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -370,19 +370,16 @@ move_rt_instructions(nir_shader *shader) return nir_progress(progress, nir_shader_get_entrypoint(shader), nir_metadata_control_flow); } -static VkResult -radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, - struct radv_ray_tracing_pipeline *pipeline, enum radv_rt_lowering_mode mode, - struct radv_shader_stage *stage, uint32_t *payload_size, uint32_t *hit_attrib_size, - uint32_t *stack_size, struct radv_ray_tracing_stage_info *stage_info, - const struct radv_ray_tracing_stage_info *traversal_stage_info, - struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache, - bool has_position_fetch, struct radv_shader **out_shader) +static void +radv_rt_nir_to_asm(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, + enum radv_rt_lowering_mode mode, struct radv_shader_stage *stage, uint32_t *payload_size, + uint32_t *hit_attrib_size, struct radv_ray_tracing_stage_info *stage_info, + const struct radv_ray_tracing_stage_info *traversal_stage_info, bool has_position_fetch, + struct radv_shader_binary **binary, struct radv_shader_debug_info *debug) { struct radv_physical_device *pdev = radv_device_physical(device); struct radv_instance *instance = radv_physical_device_instance(pdev); - struct radv_shader_binary *binary; bool keep_executable_info = radv_pipeline_capture_shaders(device, pipeline->base.base.create_flags); bool keep_statistic_info = radv_pipeline_capture_shader_stats(device, pipeline->base.base.create_flags); @@ -436,8 +433,6 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, unsigned num_shaders = num_resume_shaders + 1; nir_shader **shaders = ralloc_array(mem_ctx, nir_shader *, num_shaders); - if (!shaders) - return VK_ERROR_OUT_OF_HOST_MEMORY; shaders[0] = stage->nir; for (uint32_t i = 0; i < num_resume_shaders; i++) @@ -488,13 +483,10 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, radv_gather_unused_args(stage_info, temp_stage.nir); } - bool dump_shader = radv_can_dump_shader(device, stage->nir); - bool dump_nir = dump_shader && (instance->debug_flags & RADV_DEBUG_DUMP_NIR); - bool replayable = (pipeline->base.base.create_flags & - VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) && - !radv_is_traversal_shader(stage->nir); + debug->dump_shader = radv_can_dump_shader(device, stage->nir); + bool dump_nir = debug->dump_shader && (instance->debug_flags & RADV_DEBUG_DUMP_NIR); - if (dump_shader) { + if (debug->dump_shader) { simple_mtx_lock(&instance->shader_dump_mtx); if (dump_nir) { @@ -504,49 +496,63 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, } /* Compile NIR shader to AMD assembly. */ - binary = + *binary = radv_shader_nir_to_asm(device, stage, shaders, num_shaders, NULL, keep_executable_info, keep_statistic_info); /* Dump NIR after nir_to_asm, because ACO modifies it. */ - char *nir_string = NULL; - if (keep_executable_info || dump_shader) - nir_string = radv_dump_nir_shaders(instance, shaders, num_shaders); + if (keep_executable_info || debug->dump_shader) + debug->nir_string = radv_dump_nir_shaders(instance, shaders, num_shaders); + + radv_parse_binary_debug_info(device, *binary, debug); + debug->stages = 1 << shaders[0]->info.stage; + + radv_shader_dump_asm(device, debug, &stage->info); + + if (keep_executable_info && stage->spirv.size) { + debug->spirv = malloc(stage->spirv.size); + memcpy(debug->spirv, stage->spirv.data, stage->spirv.size); + debug->spirv_size = stage->spirv.size; + } + + if (debug->dump_shader) + simple_mtx_unlock(&instance->shader_dump_mtx); + + ralloc_free(mem_ctx); +} + +static VkResult +radv_rt_compile_nir(struct radv_device *device, struct vk_pipeline_cache *cache, + struct radv_ray_tracing_pipeline *pipeline, enum radv_rt_lowering_mode mode, + struct radv_shader_stage *stage, uint32_t *payload_size, uint32_t *hit_attrib_size, + uint32_t *stack_size, struct radv_ray_tracing_stage_info *stage_info, + const struct radv_ray_tracing_stage_info *traversal_stage_info, + struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache, + bool has_position_fetch, struct radv_shader **out_shader) +{ + bool replayable = (pipeline->base.base.create_flags & + VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) && + !radv_is_traversal_shader(stage->nir); + + struct radv_shader_binary *binary; + struct radv_shader_debug_info debug = {}; + radv_rt_nir_to_asm(device, pipeline, mode, stage, payload_size, hit_attrib_size, stage_info, traversal_stage_info, + has_position_fetch, &binary, &debug); struct radv_shader *shader; if (replay_block || replayable) { VkResult result = radv_shader_create_uncached(device, binary, replayable, replay_block, &shader); if (result != VK_SUCCESS) { - if (dump_shader) - simple_mtx_unlock(&instance->shader_dump_mtx); - free(binary); return result; } - } else - shader = radv_shader_create(device, cache, binary, skip_shaders_cache || dump_shader, NULL); - - if (shader) { - radv_parse_binary_debug_info(device, binary, &shader->dbg); - shader->dbg.nir_string = nir_string; - shader->dbg.stages = 1 << shaders[0]->info.stage; - shader->dbg.dump_shader = dump_shader; - - if (stack_size) - *stack_size = DIV_ROUND_UP(shader->config.scratch_bytes_per_wave, shader->info.wave_size); - - radv_shader_dump_asm(device, &shader->dbg, &stage->info); - - if (shader && keep_executable_info && stage->spirv.size) { - shader->dbg.spirv = malloc(stage->spirv.size); - memcpy(shader->dbg.spirv, stage->spirv.data, stage->spirv.size); - shader->dbg.spirv_size = stage->spirv.size; - } + shader->dbg = debug; + } else { + shader = radv_shader_create(device, cache, binary, skip_shaders_cache, &debug); } - if (dump_shader) - simple_mtx_unlock(&instance->shader_dump_mtx); + if (shader && stack_size) + *stack_size = DIV_ROUND_UP(shader->config.scratch_bytes_per_wave, shader->info.wave_size); - ralloc_free(mem_ctx); free(binary); *out_shader = shader; @@ -808,9 +814,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca enum radv_rt_lowering_mode mode = stage->stage == MESA_SHADER_RAYGEN ? raygen_lowering_mode : recursive_lowering_mode; - result = radv_rt_nir_to_asm(device, cache, pipeline, mode, stage, &payload_size, &hit_attrib_size, &stack_size, - &rt_stages[idx].info, NULL, replay_block, skip_shaders_cache, has_position_fetch, - &rt_stages[idx].shader); + result = radv_rt_compile_nir(device, cache, pipeline, mode, stage, &payload_size, &hit_attrib_size, + &stack_size, &rt_stages[idx].info, NULL, replay_block, skip_shaders_cache, + has_position_fetch, &rt_stages[idx].shader); if (result != VK_SUCCESS) goto cleanup; @@ -864,9 +870,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca struct radv_serialized_shader_arena_block *replay_block = capture_replay_handles[idx].arena_va ? &capture_replay_handles[idx] : NULL; - result = radv_rt_nir_to_asm(device, cache, pipeline, RADV_RT_LOWERING_MODE_FUNCTION_CALLS, &combined_stage, - &payload_size, &hit_attrib_size, &stack_size, NULL, NULL, replay_block, - skip_shaders_cache, has_position_fetch, &pipeline->groups[idx].ahit_isec_shader); + result = radv_rt_compile_nir(device, cache, pipeline, RADV_RT_LOWERING_MODE_FUNCTION_CALLS, &combined_stage, + &payload_size, &hit_attrib_size, &stack_size, NULL, NULL, replay_block, + skip_shaders_cache, has_position_fetch, &pipeline->groups[idx].ahit_isec_shader); if (result != VK_SUCCESS) goto cleanup; @@ -927,10 +933,10 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca .key = stage_keys[MESA_SHADER_INTERSECTION], }; radv_shader_layout_init(pipeline_layout, MESA_SHADER_INTERSECTION, &traversal_stage.layout); - result = radv_rt_nir_to_asm(device, cache, pipeline, recursive_lowering_mode, &traversal_stage, &payload_size, - &hit_attrib_size, &pipeline->traversal_stack_size, NULL, &traversal_info, NULL, - skip_shaders_cache, has_position_fetch, - &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]); + result = radv_rt_compile_nir(device, cache, pipeline, recursive_lowering_mode, &traversal_stage, &payload_size, + &hit_attrib_size, &pipeline->traversal_stack_size, NULL, &traversal_info, NULL, + skip_shaders_cache, has_position_fetch, + &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]); ralloc_free(traversal_nir); cleanup: