diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index c1b0042cc5d..a1bc4aacb01 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -1684,6 +1684,47 @@ radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR
    return false;
 }
 
+static bool
+should_move_rt_instruction(nir_intrinsic_op intrinsic)
+{
+   switch (intrinsic) {
+   case nir_intrinsic_load_rt_arg_scratch_offset_amd:
+   case nir_intrinsic_load_ray_flags:
+   case nir_intrinsic_load_ray_object_origin:
+   case nir_intrinsic_load_ray_world_origin:
+   case nir_intrinsic_load_ray_t_min:
+   case nir_intrinsic_load_ray_object_direction:
+   case nir_intrinsic_load_ray_world_direction:
+   case nir_intrinsic_load_ray_t_max:
+      return true;
+   default:
+      return false;
+   }
+}
+
+static void
+move_rt_instructions(nir_shader *shader)
+{
+   nir_cursor target = nir_before_cf_list(&nir_shader_get_entrypoint(shader)->body);
+
+   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
+      nir_foreach_instr_safe (instr, block) {
+         if (instr->type != nir_instr_type_intrinsic)
+            continue;
+
+         nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
+
+         if (!should_move_rt_instruction(intrinsic->intrinsic))
+            continue;
+
+         nir_instr_move(target, instr);
+      }
+   }
+
+   nir_metadata_preserve(nir_shader_get_entrypoint(shader),
+                         nir_metadata_all & (~nir_metadata_instr_index));
+}
+
 static nir_shader *
 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                  struct radv_pipeline_shader_stack_size *stack_sizes)
@@ -1743,6 +1784,11 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
       nir_shader *nir_stage = parse_rt_stage(device, stage);
 
+      /* Move the ray tracing system values that are set by rt_trace_ray to the
+       * top to prevent them from being overwritten by other rt_trace_ray calls.
+       */
+      NIR_PASS_V(nir_stage, move_rt_instructions);
+
       uint32_t num_resume_shaders = 0;
       nir_shader **resume_shaders = NULL;
       nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders,
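
Note (illustration only, not part of the patch): the hoist matters because the ray system values live in storage that every rt_trace_ray call rewrites, so any reads of them have to happen before the shader issues its own trace calls. The standalone C sketch below mimics that hazard with an ordinary global instead of the real NIR intrinsics; hypothetical_trace_ray and g_ray_t_max are invented names for the example, not radv or NIR APIs.

#include <stdio.h>

static float g_ray_t_max; /* stand-in for a ray system value slot */

static void hypothetical_trace_ray(float t_max)
{
   g_ray_t_max = t_max; /* every trace call overwrites the shared slot */
}

int main(void)
{
   hypothetical_trace_ray(100.0f); /* the trace that "launched" this shader */

   /* Hoisted read, analogous to what move_rt_instructions achieves:
    * capture the value before any further trace call can clobber it. */
   float our_t_max = g_ray_t_max;

   hypothetical_trace_ray(5.0f); /* a later trace overwrites the slot */

   /* Reading g_ray_t_max here would give 5.0; the hoisted copy is intact. */
   printf("t_max = %.1f\n", our_t_max); /* prints 100.0 */
   return 0;
}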