diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index 23dd7100aa2..0cb765ef06d 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -1620,9 +1620,13 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
        */
       NIR_PASS_V(nir_stage, move_rt_instructions);
 
+      const nir_lower_shader_calls_options opts = {
+         .address_format = nir_address_format_32bit_offset,
+         .stack_alignment = 16,
+      };
       uint32_t num_resume_shaders = 0;
       nir_shader **resume_shaders = NULL;
-      nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders,
+      nir_lower_shader_calls(nir_stage, &opts, &resume_shaders,
                              &num_resume_shaders, nir_stage);
 
       vars.stage_idx = i;
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 66c30471426..c77cd52a681 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -4834,10 +4834,17 @@
 bool nir_lower_explicit_io(nir_shader *shader,
                            nir_variable_mode modes,
                            nir_address_format);
 
+typedef struct nir_lower_shader_calls_options {
+   /* Address format used for load/store operations on the call stack. */
+   nir_address_format address_format;
+
+   /* Stack alignment */
+   unsigned stack_alignment;
+} nir_lower_shader_calls_options;
+
 bool nir_lower_shader_calls(nir_shader *shader,
-                            nir_address_format address_format,
-                            unsigned stack_alignment,
+                            const nir_lower_shader_calls_options *options,
                             nir_shader ***resume_shaders_out,
                             uint32_t *num_resume_shaders_out,
                             void *mem_ctx);
diff --git a/src/compiler/nir/nir_lower_shader_calls.c b/src/compiler/nir/nir_lower_shader_calls.c
index 8ec5a4b9a3d..e34c8c0b1ff 100644
--- a/src/compiler/nir/nir_lower_shader_calls.c
+++ b/src/compiler/nir/nir_lower_shader_calls.c
@@ -1424,8 +1424,7 @@ nir_opt_remove_respills(nir_shader *shader)
  */
 bool
 nir_lower_shader_calls(nir_shader *shader,
-                       nir_address_format address_format,
-                       unsigned stack_alignment,
+                       const nir_lower_shader_calls_options *options,
                        nir_shader ***resume_shaders_out,
                        uint32_t *num_resume_shaders_out,
                        void *mem_ctx)
@@ -1461,7 +1460,7 @@ nir_lower_shader_calls(nir_shader *shader,
    }
 
    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
-              num_calls, stack_alignment);
+              num_calls, options->stack_alignment);
 
    NIR_PASS_V(shader, nir_opt_remove_phis);
 
@@ -1494,9 +1493,12 @@
    for (unsigned i = 0; i < num_calls; i++)
       NIR_PASS_V(resume_shaders[i], nir_opt_remove_respills);
 
-   NIR_PASS_V(shader, nir_lower_stack_to_scratch, address_format);
-   for (unsigned i = 0; i < num_calls; i++)
-      NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch, address_format);
+   NIR_PASS_V(shader, nir_lower_stack_to_scratch,
+              options->address_format);
+   for (unsigned i = 0; i < num_calls; i++) {
+      NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch,
+                 options->address_format);
+   }
 
    *resume_shaders_out = resume_shaders;
    *num_resume_shaders_out = num_calls;
diff --git a/src/intel/vulkan/anv_pipeline.c b/src/intel/vulkan/anv_pipeline.c
index f2cdeb23150..b20efaad304 100644
--- a/src/intel/vulkan/anv_pipeline.c
+++ b/src/intel/vulkan/anv_pipeline.c
@@ -2465,9 +2465,12 @@ compile_upload_rt_shader(struct anv_ray_tracing_pipeline *pipeline,
    nir_shader **resume_shaders = NULL;
    uint32_t num_resume_shaders = 0;
    if (nir->info.stage != MESA_SHADER_COMPUTE) {
-      NIR_PASS(_, nir, nir_lower_shader_calls,
-               nir_address_format_64bit_global,
-               BRW_BTD_STACK_ALIGN,
+      const nir_lower_shader_calls_options opts = {
+         .address_format = nir_address_format_64bit_global,
+         .stack_alignment = BRW_BTD_STACK_ALIGN,
+      };
+
+      NIR_PASS(_, nir, nir_lower_shader_calls, &opts,
               &resume_shaders, &num_resume_shaders, mem_ctx);
       NIR_PASS(_, nir, brw_nir_lower_shader_calls, &stage->key.bs);
       NIR_PASS_V(nir, brw_nir_lower_rt_intrinsics, devinfo);
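
Note (not part of the patch): any other caller of nir_lower_shader_calls needs the
same mechanical update from positional arguments to the options struct. A minimal
sketch of the new calling convention, assuming a hypothetical backend that previously
passed nir_address_format_32bit_offset and a 16-byte stack alignment:

   /* Options that used to be the second and third positional arguments. */
   const nir_lower_shader_calls_options opts = {
      .address_format = nir_address_format_32bit_offset,
      .stack_alignment = 16,
   };

   nir_shader **resume_shaders = NULL;
   uint32_t num_resume_shaders = 0;
   nir_lower_shader_calls(shader, &opts,
                          &resume_shaders, &num_resume_shaders, mem_ctx);

Since the options now live in a struct built with designated initializers, additional
fields can later be added with zero-initialized defaults without touching every call
site again.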