diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index df337fd0746..b16da5c4355 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -273,7 +273,7 @@ load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_def *idx, en /* This lowers all the RT instructions that we do not want to pass on to the combined shader and * that we can implement using the variables from the shader we are going to inline into. */ static void -lower_rt_instructions(nir_shader *shader, struct rt_variables *vars) +lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool apply_stack_ptr) { nir_builder b_shader = nir_builder_create(nir_shader_get_entrypoint(shader)); @@ -350,13 +350,15 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars) break; } case nir_intrinsic_load_scratch: { - nir_src_rewrite(&intr->src[0], - nir_iadd_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa)); + if (apply_stack_ptr) + nir_src_rewrite(&intr->src[0], + nir_iadd_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa)); continue; } case nir_intrinsic_store_scratch: { - nir_src_rewrite(&intr->src[1], - nir_iadd_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa)); + if (apply_stack_ptr) + nir_src_rewrite(&intr->src[1], + nir_iadd_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa)); continue; } case nir_intrinsic_load_rt_arg_scratch_offset_amd: { @@ -780,7 +782,7 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni struct rt_variables src_vars = create_rt_variables(shader, vars->flags); map_rt_variables(var_remap, &src_vars, vars); - NIR_PASS_V(shader, lower_rt_instructions, &src_vars); + NIR_PASS_V(shader, lower_rt_instructions, &src_vars, false); NIR_PASS(_, shader, nir_lower_returns); NIR_PASS(_, shader, nir_opt_dce); @@ -1522,7 +1524,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH const VkPipelineCreateFlagBits2KHR create_flags = radv_get_pipeline_create_flags(pCreateInfo); struct rt_variables vars = create_rt_variables(shader, create_flags); - lower_rt_instructions(shader, &vars); + lower_rt_instructions(shader, &vars, true); if (stack_size) { vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);