radv/rt: Fix some tail-call compatibility checks

There were two issues here:
1. Tail calls where the tail-callee receives modified parameters are
hazardous and only work if the parameter is a return parameter or is
marked discardable. Otherwise, the caller of the function that executes
the tail call may not expect some of the parameters to be clobbered.
2. There was also an indexing confusion between the parameters of the
call instruction and the parameters of the call signature. The call
instruction has not been adapted to the new lowered signatures, where
the system args are prepended. To make things clearer, split the loop
into two: one iterating over the parameters in the call signature and
one iterating over the parameters of the call instruction.

Cc: mesa-stable
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39579>
(cherry picked from commit 0d7705c206)
This commit is contained in:
Natalie Vock 2026-01-28 12:10:17 +01:00 committed by Eric Engestrom
parent 0753012766
commit 8ab3b18cd7
2 changed files with 34 additions and 11 deletions

View file

@ -1624,7 +1624,7 @@
"description": "radv/rt: Fix some tail-call compatibility checks",
"nominated": true,
"nomination_type": 1,
"resolution": 0,
"resolution": 1,
"main_sha": null,
"because_sha": null,
"notes": null

View file

@ -114,11 +114,32 @@ gather_tail_call_instrs_block(nir_function *caller, const struct nir_block *bloc
if (call->callee->num_params != caller->num_params)
return;
for (unsigned i = 0; i < call->num_params; ++i) {
for (unsigned i = 0; i < call->callee->num_params; ++i) {
if (call->callee->params[i].is_return != caller->params[i].is_return)
return;
if ((call->callee->params[i].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE) &&
!(caller->params[i].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE))
return;
bool has_preserved_regs =
(caller->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_ABI_MASK) == ACO_NIR_CALL_ABI_AHIT_ISEC;
if (has_preserved_regs && ((call->callee->params[i].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE) !=
(caller->params[i].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE)))
return;
if (call->callee->params[i].is_uniform != caller->params[i].is_uniform)
return;
if (call->callee->params[i].bit_size != caller->params[i].bit_size)
return;
if (call->callee->params[i].num_components != caller->params[i].num_components)
return;
}
/* The call instruction itself has not been lowered to the new signature yet, so do this in a separate loop and
* adjust parameter indices for the caller.
*/
for (unsigned i = 0; i < call->num_params; ++i) {
unsigned caller_param_idx = i + ACO_NIR_CALL_SYSTEM_ARG_COUNT;
/* We can only do tail calls if the caller returns exactly the callee return values */
if (caller->params[i].is_return) {
if (caller->params[caller_param_idx].is_return) {
assert(nir_def_as_deref_or_null(call->params[i].ssa));
nir_deref_instr *deref_root = nir_def_as_deref(call->params[i].ssa);
while (nir_deref_instr_parent(deref_root))
@ -129,16 +150,18 @@ gather_tail_call_instrs_block(nir_function *caller, const struct nir_block *bloc
nir_intrinsic_instr *intrin = nir_def_as_intrinsic_or_null(deref_root->parent.ssa);
if (!intrin || intrin->intrinsic != nir_intrinsic_load_param)
return;
/* The call parameters aren't lowered at this point, we need to add the call arg count here */
if (nir_intrinsic_param_idx(intrin) != i + ACO_NIR_CALL_SYSTEM_ARG_COUNT)
if (nir_intrinsic_param_idx(intrin) != caller_param_idx)
return;
} else if (!(caller->params[caller_param_idx].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE)) {
/* If the parameter is not marked as discardable, then we have to preserve the caller's value. Passing
* a modified value to a tail call leaves us unable to restore the original value, so bail out if we have
* modified parameters.
*/
nir_intrinsic_instr *intrin = nir_def_as_intrinsic_or_null(call->params[i].ssa);
if (!intrin || intrin->intrinsic != nir_intrinsic_load_param ||
nir_intrinsic_param_idx(intrin) != caller_param_idx)
return;
}
if (call->callee->params[i].is_uniform != caller->params[i].is_uniform)
return;
if (call->callee->params[i].bit_size != caller->params[i].bit_size)
return;
if (call->callee->params[i].num_components != caller->params[i].num_components)
return;
}
_mesa_set_add(tail_calls, instr);