diff --git a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c index 5fe0b29d074..e4080ce59e6 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c +++ b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c @@ -144,6 +144,7 @@ radv_get_ray_query_type() struct ray_query_vars { nir_variable *var; + bool use_bvh_stack_rtn; bool shared_stack; uint32_t shared_base; uint32_t stack_entries; @@ -165,10 +166,18 @@ init_ray_query_vars(nir_shader *shader, const glsl_type *opaque_type, struct ray workgroup_size = MAX2(workgroup_size, 32); uint32_t shared_stack_size = workgroup_size * shared_stack_entries * 4; uint32_t shared_offset = align(shader->info.shared_size, 4); + if (shader->info.stage != MESA_SHADER_COMPUTE || glsl_type_is_array(opaque_type) || shared_offset + shared_stack_size > pdev->max_shared_size) { dst->stack_entries = MAX_SCRATCH_STACK_ENTRY_COUNT; } else { + if (radv_use_bvh_stack_rtn(pdev)) { + /* The hardware ds_bvh_stack_rtn address can only encode a stack base up to 8191 dwords. */ + uint32_t num_wave32_groups = DIV_ROUND_UP(workgroup_size, 32); + uint32_t max_group_stack_base = (num_wave32_groups - 1) * 32 * shared_stack_entries; + uint32_t max_stack_base = (shared_offset / 4) + max_group_stack_base; + dst->use_bvh_stack_rtn = max_stack_base < 8192; + } dst->shared_stack = true; dst->shared_base = shared_offset; dst->stack_entries = shared_stack_entries; @@ -303,7 +312,7 @@ lower_rq_initialize(nir_builder *b, nir_intrinsic_instr *instr, struct ray_query if (vars->shared_stack) { nir_def *stack_idx = nir_load_local_invocation_index(b); - if (radv_use_bvh_stack_rtn(pdev)) { + if (vars->use_bvh_stack_rtn) { uint32_t workgroup_size = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2]; nir_def *addr = @@ -563,7 +572,7 @@ lower_rq_proceed(nir_builder *b, nir_intrinsic_instr *instr, struct ray_query_va }; if (vars->shared_stack) { - args.use_bvh_stack_rtn = radv_use_bvh_stack_rtn(pdev); + args.use_bvh_stack_rtn = vars->use_bvh_stack_rtn; if (args.use_bvh_stack_rtn) { args.stack_stride = 1; args.stack_base = 0;