diff --git a/.pick_status.json b/.pick_status.json index 0e25b04e3f4..eb941a6af46 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -1624,7 +1624,7 @@ "description": "radv/rt: Fix some tail-call compatibility checks", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/amd/vulkan/nir/radv_nir_lower_call_abi.c b/src/amd/vulkan/nir/radv_nir_lower_call_abi.c index 76e141472d8..60619de53fb 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_call_abi.c +++ b/src/amd/vulkan/nir/radv_nir_lower_call_abi.c @@ -114,11 +114,32 @@ gather_tail_call_instrs_block(nir_function *caller, const struct nir_block *bloc if (call->callee->num_params != caller->num_params) return; - for (unsigned i = 0; i < call->num_params; ++i) { + for (unsigned i = 0; i < call->callee->num_params; ++i) { if (call->callee->params[i].is_return != caller->params[i].is_return) return; + if ((call->callee->params[i].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE) && + !(caller->params[i].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE)) + return; + bool has_preserved_regs = + (caller->driver_attributes & ACO_NIR_FUNCTION_ATTRIB_ABI_MASK) == ACO_NIR_CALL_ABI_AHIT_ISEC; + if (has_preserved_regs && ((call->callee->params[i].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE) != + (caller->params[i].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE))) + return; + if (call->callee->params[i].is_uniform != caller->params[i].is_uniform) + return; + if (call->callee->params[i].bit_size != caller->params[i].bit_size) + return; + if (call->callee->params[i].num_components != caller->params[i].num_components) + return; + } + + /* The call instruction itself has not been lowered to the new signature yet, so do this in a separate loop and + * adjust parameter indices for the caller. + */ + for (unsigned i = 0; i < call->num_params; ++i) { + unsigned caller_param_idx = i + ACO_NIR_CALL_SYSTEM_ARG_COUNT; /* We can only do tail calls if the caller returns exactly the callee return values */ - if (caller->params[i].is_return) { + if (caller->params[caller_param_idx].is_return) { assert(nir_def_as_deref_or_null(call->params[i].ssa)); nir_deref_instr *deref_root = nir_def_as_deref(call->params[i].ssa); while (nir_deref_instr_parent(deref_root)) @@ -129,16 +150,18 @@ gather_tail_call_instrs_block(nir_function *caller, const struct nir_block *bloc nir_intrinsic_instr *intrin = nir_def_as_intrinsic_or_null(deref_root->parent.ssa); if (!intrin || intrin->intrinsic != nir_intrinsic_load_param) return; - /* The call parameters aren't lowered at this point, we need to add the call arg count here */ - if (nir_intrinsic_param_idx(intrin) != i + ACO_NIR_CALL_SYSTEM_ARG_COUNT) + if (nir_intrinsic_param_idx(intrin) != caller_param_idx) + return; + } else if (!(caller->params[caller_param_idx].driver_attributes & ACO_NIR_PARAM_ATTRIB_DISCARDABLE)) { + /* If the parameter is not marked as discardable, then we have to preserve the caller's value. Passing + * a modified value to a tail call leaves us unable to restore the original value, so bail out if we have + * modified parameters. + */ + nir_intrinsic_instr *intrin = nir_def_as_intrinsic_or_null(call->params[i].ssa); + if (!intrin || intrin->intrinsic != nir_intrinsic_load_param || + nir_intrinsic_param_idx(intrin) != caller_param_idx) return; } - if (call->callee->params[i].is_uniform != caller->params[i].is_uniform) - return; - if (call->callee->params[i].bit_size != caller->params[i].bit_size) - return; - if (call->callee->params[i].num_components != caller->params[i].num_components) - return; } _mesa_set_add(tail_calls, instr);