radv/rt: Fix shared ray query stack on top of application LDS

Since the stack pointer may wrap around the stack size in overflow
cases, traversal logic calculates the real stack pointer with
nir_umod_imm(b, stack, args->stack_entries * args->stack_stride).

For ray queries, "stack" was initialized to
"stack_base + local_invocation_idx * 4". This was completely broken, as
the umod would later discard the stack base, so stack accesses would
overwrite the start of LDS, which belongs to the app's shared memory.

Instead, add the stack base as a constant offset in the load/store_stack
callbacks. (This should also save 1 VALU instruction per ray query.)
Also, delete radv_ray_traversal_args::stack_base since it's unused now.

Cc: mesa-stable
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40420>
This commit is contained in:
Natalie Vock 2026-03-14 15:59:58 +01:00 committed by Marge Bot
parent 28dd08755c
commit b046eaf36d
4 changed files with 5 additions and 11 deletions

View file

@ -323,7 +323,6 @@ lower_rq_initialize(nir_builder *b, nir_intrinsic_instr *instr, struct ray_query
rq_store(b, rq, trav_stack_low_watermark, addr);
} else {
nir_def *base_offset = nir_imul_imm(b, stack_idx, sizeof(uint32_t));
base_offset = nir_iadd_imm(b, base_offset, vars->shared_base);
rq_store(b, rq, trav_stack, base_offset);
rq_store(b, rq, trav_stack_low_watermark, base_offset);
}
@ -493,7 +492,7 @@ store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct r
struct traversal_data *data = args->data;
if (data->vars->shared_stack)
nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
nir_store_shared(b, value, index, .base = data->vars->shared_base, .align_mul = 4);
else
nir_store_deref(b, nir_build_deref_array(b, rq_deref(b, data->rq, stack), index), value, 0x1);
}
@ -504,7 +503,7 @@ load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal
struct traversal_data *data = args->data;
if (data->vars->shared_stack)
return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
return nir_load_shared(b, 1, 32, index, .base = data->vars->shared_base, .align_mul = 4);
else
return nir_load_deref(b, nir_build_deref_array(b, rq_deref(b, data->rq, stack), index));
}
@ -577,16 +576,13 @@ lower_rq_proceed(nir_builder *b, nir_intrinsic_instr *instr, struct ray_query_va
args.use_bvh_stack_rtn = vars->use_bvh_stack_rtn;
if (args.use_bvh_stack_rtn) {
args.stack_stride = 1;
args.stack_base = 0;
} else {
uint32_t workgroup_size =
b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2];
args.stack_stride = workgroup_size * 4;
args.stack_base = vars->shared_base;
}
} else {
args.stack_stride = 1;
args.stack_base = 0;
}
rq_store(b, rq, break_flag, nir_imm_false(b));

View file

@ -878,7 +878,7 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
/* Early exit if we never overflowed the stack, to avoid having to backtrack to
* the root for no reason. */
if (!args->use_bvh_stack_rtn) {
nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride));
{
nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
nir_jump(b, nir_jump_break);
@ -1174,7 +1174,7 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
/* Early exit if we never overflowed the stack, to avoid having to backtrack to
* the root for no reason. */
if (!args->use_bvh_stack_rtn) {
nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride));
{
nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
nir_jump(b, nir_jump_break);

View file

@ -135,10 +135,9 @@ struct radv_ray_traversal_args {
struct radv_ray_traversal_vars vars;
/* The increment/decrement used for radv_ray_traversal_vars::stack, and how many entries are
* available. stack_base is the base address of the stack. */
* available. */
uint32_t stack_stride;
uint32_t stack_entries;
uint32_t stack_base;
uint32_t set_flags;
uint32_t unset_flags;

View file

@ -1251,7 +1251,6 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
.vars = trav_vars_args,
.stack_stride = stack_stride,
.stack_entries = MAX_STACK_ENTRY_COUNT,
.stack_base = 0,
.ignore_cull_mask = params->ignore_cull_mask,
.set_flags = info ? info->set_flags : 0,
.unset_flags = info ? info->unset_flags : 0,