radv: move radv_shader_create out of radv_rt_nir_to_asm

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40627>
This commit is contained in:
Rhys Perry 2026-03-25 10:38:14 +00:00 committed by Marge Bot
parent 2260105ba1
commit 4c3a74bebe

View file

@@ -370,19 +370,16 @@ move_rt_instructions(nir_shader *shader)
return nir_progress(progress, nir_shader_get_entrypoint(shader), nir_metadata_control_flow);
}
static VkResult
radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
struct radv_ray_tracing_pipeline *pipeline, enum radv_rt_lowering_mode mode,
struct radv_shader_stage *stage, uint32_t *payload_size, uint32_t *hit_attrib_size,
uint32_t *stack_size, struct radv_ray_tracing_stage_info *stage_info,
const struct radv_ray_tracing_stage_info *traversal_stage_info,
struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache,
bool has_position_fetch, struct radv_shader **out_shader)
static void
radv_rt_nir_to_asm(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
enum radv_rt_lowering_mode mode, struct radv_shader_stage *stage, uint32_t *payload_size,
uint32_t *hit_attrib_size, struct radv_ray_tracing_stage_info *stage_info,
const struct radv_ray_tracing_stage_info *traversal_stage_info, bool has_position_fetch,
struct radv_shader_binary **binary, struct radv_shader_debug_info *debug)
{
struct radv_physical_device *pdev = radv_device_physical(device);
struct radv_instance *instance = radv_physical_device_instance(pdev);
struct radv_shader_binary *binary;
bool keep_executable_info = radv_pipeline_capture_shaders(device, pipeline->base.base.create_flags);
bool keep_statistic_info = radv_pipeline_capture_shader_stats(device, pipeline->base.base.create_flags);
@@ -436,8 +433,6 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
unsigned num_shaders = num_resume_shaders + 1;
nir_shader **shaders = ralloc_array(mem_ctx, nir_shader *, num_shaders);
if (!shaders)
return VK_ERROR_OUT_OF_HOST_MEMORY;
shaders[0] = stage->nir;
for (uint32_t i = 0; i < num_resume_shaders; i++)
@@ -488,13 +483,10 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
radv_gather_unused_args(stage_info, temp_stage.nir);
}
bool dump_shader = radv_can_dump_shader(device, stage->nir);
bool dump_nir = dump_shader && (instance->debug_flags & RADV_DEBUG_DUMP_NIR);
bool replayable = (pipeline->base.base.create_flags &
VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) &&
!radv_is_traversal_shader(stage->nir);
debug->dump_shader = radv_can_dump_shader(device, stage->nir);
bool dump_nir = debug->dump_shader && (instance->debug_flags & RADV_DEBUG_DUMP_NIR);
if (dump_shader) {
if (debug->dump_shader) {
simple_mtx_lock(&instance->shader_dump_mtx);
if (dump_nir) {
@@ -504,49 +496,63 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
}
/* Compile NIR shader to AMD assembly. */
binary =
*binary =
radv_shader_nir_to_asm(device, stage, shaders, num_shaders, NULL, keep_executable_info, keep_statistic_info);
/* Dump NIR after nir_to_asm, because ACO modifies it. */
char *nir_string = NULL;
if (keep_executable_info || dump_shader)
nir_string = radv_dump_nir_shaders(instance, shaders, num_shaders);
if (keep_executable_info || debug->dump_shader)
debug->nir_string = radv_dump_nir_shaders(instance, shaders, num_shaders);
radv_parse_binary_debug_info(device, *binary, debug);
debug->stages = 1 << shaders[0]->info.stage;
radv_shader_dump_asm(device, debug, &stage->info);
if (keep_executable_info && stage->spirv.size) {
debug->spirv = malloc(stage->spirv.size);
memcpy(debug->spirv, stage->spirv.data, stage->spirv.size);
debug->spirv_size = stage->spirv.size;
}
if (debug->dump_shader)
simple_mtx_unlock(&instance->shader_dump_mtx);
ralloc_free(mem_ctx);
}
static VkResult
radv_rt_compile_nir(struct radv_device *device, struct vk_pipeline_cache *cache,
struct radv_ray_tracing_pipeline *pipeline, enum radv_rt_lowering_mode mode,
struct radv_shader_stage *stage, uint32_t *payload_size, uint32_t *hit_attrib_size,
uint32_t *stack_size, struct radv_ray_tracing_stage_info *stage_info,
const struct radv_ray_tracing_stage_info *traversal_stage_info,
struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache,
bool has_position_fetch, struct radv_shader **out_shader)
{
bool replayable = (pipeline->base.base.create_flags &
VK_PIPELINE_CREATE_2_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) &&
!radv_is_traversal_shader(stage->nir);
struct radv_shader_binary *binary;
struct radv_shader_debug_info debug = {};
radv_rt_nir_to_asm(device, pipeline, mode, stage, payload_size, hit_attrib_size, stage_info, traversal_stage_info,
has_position_fetch, &binary, &debug);
struct radv_shader *shader;
if (replay_block || replayable) {
VkResult result = radv_shader_create_uncached(device, binary, replayable, replay_block, &shader);
if (result != VK_SUCCESS) {
if (dump_shader)
simple_mtx_unlock(&instance->shader_dump_mtx);
free(binary);
return result;
}
} else
shader = radv_shader_create(device, cache, binary, skip_shaders_cache || dump_shader, NULL);
if (shader) {
radv_parse_binary_debug_info(device, binary, &shader->dbg);
shader->dbg.nir_string = nir_string;
shader->dbg.stages = 1 << shaders[0]->info.stage;
shader->dbg.dump_shader = dump_shader;
if (stack_size)
*stack_size = DIV_ROUND_UP(shader->config.scratch_bytes_per_wave, shader->info.wave_size);
radv_shader_dump_asm(device, &shader->dbg, &stage->info);
if (shader && keep_executable_info && stage->spirv.size) {
shader->dbg.spirv = malloc(stage->spirv.size);
memcpy(shader->dbg.spirv, stage->spirv.data, stage->spirv.size);
shader->dbg.spirv_size = stage->spirv.size;
}
shader->dbg = debug;
} else {
shader = radv_shader_create(device, cache, binary, skip_shaders_cache, &debug);
}
if (dump_shader)
simple_mtx_unlock(&instance->shader_dump_mtx);
if (shader && stack_size)
*stack_size = DIV_ROUND_UP(shader->config.scratch_bytes_per_wave, shader->info.wave_size);
ralloc_free(mem_ctx);
free(binary);
*out_shader = shader;
@@ -808,9 +814,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
enum radv_rt_lowering_mode mode =
stage->stage == MESA_SHADER_RAYGEN ? raygen_lowering_mode : recursive_lowering_mode;
result = radv_rt_nir_to_asm(device, cache, pipeline, mode, stage, &payload_size, &hit_attrib_size, &stack_size,
&rt_stages[idx].info, NULL, replay_block, skip_shaders_cache, has_position_fetch,
&rt_stages[idx].shader);
result = radv_rt_compile_nir(device, cache, pipeline, mode, stage, &payload_size, &hit_attrib_size,
&stack_size, &rt_stages[idx].info, NULL, replay_block, skip_shaders_cache,
has_position_fetch, &rt_stages[idx].shader);
if (result != VK_SUCCESS)
goto cleanup;
@@ -864,9 +870,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
struct radv_serialized_shader_arena_block *replay_block =
capture_replay_handles[idx].arena_va ? &capture_replay_handles[idx] : NULL;
result = radv_rt_nir_to_asm(device, cache, pipeline, RADV_RT_LOWERING_MODE_FUNCTION_CALLS, &combined_stage,
&payload_size, &hit_attrib_size, &stack_size, NULL, NULL, replay_block,
skip_shaders_cache, has_position_fetch, &pipeline->groups[idx].ahit_isec_shader);
result = radv_rt_compile_nir(device, cache, pipeline, RADV_RT_LOWERING_MODE_FUNCTION_CALLS, &combined_stage,
&payload_size, &hit_attrib_size, &stack_size, NULL, NULL, replay_block,
skip_shaders_cache, has_position_fetch, &pipeline->groups[idx].ahit_isec_shader);
if (result != VK_SUCCESS)
goto cleanup;
@@ -927,10 +933,10 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
.key = stage_keys[MESA_SHADER_INTERSECTION],
};
radv_shader_layout_init(pipeline_layout, MESA_SHADER_INTERSECTION, &traversal_stage.layout);
result = radv_rt_nir_to_asm(device, cache, pipeline, recursive_lowering_mode, &traversal_stage, &payload_size,
&hit_attrib_size, &pipeline->traversal_stack_size, NULL, &traversal_info, NULL,
skip_shaders_cache, has_position_fetch,
&pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]);
result = radv_rt_compile_nir(device, cache, pipeline, recursive_lowering_mode, &traversal_stage, &payload_size,
&hit_attrib_size, &pipeline->traversal_stack_size, NULL, &traversal_info, NULL,
skip_shaders_cache, has_position_fetch,
&pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]);
ralloc_free(traversal_nir);
cleanup: