spirv: fix cooperative matrix by value function params
Some checks are pending
macOS-CI / macOS-CI (dri) (push) Waiting to run
macOS-CI / macOS-CI (xlib) (push) Waiting to run

The vtn_ssa_value for a cmat is not backed by a nir_def, but by a nir_variable, so
can't be used directly when calling a function.  In most cases the cmat is used by
reference so code will take the value of deref for it (which is a `nir_def`).

When passing a cooperative matrix to a function by value, let the caller pass the deref
value, and the callee copy to a new local variable from that deref.

Fixes: b98f87612b ("spirv: Implement SPV_KHR_cooperative_matrix")
Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34364>
This commit is contained in:
Georg Lehmann 2025-04-03 16:12:56 +02:00 committed by Marge Bot
parent 864ae91392
commit 0cad7b0968

View file

@ -78,7 +78,10 @@ vtn_ssa_value_add_to_call_params(struct vtn_builder *b,
nir_call_instr *call,
unsigned *param_idx)
{
if (glsl_type_is_vector_or_scalar(value->type)) {
if (glsl_type_is_cmat(value->type)) {
nir_deref_instr *src_deref = vtn_get_deref_for_ssa_value(b, value);
call->params[(*param_idx)++] = nir_src_for_ssa(&src_deref->def);
} else if (glsl_type_is_vector_or_scalar(value->type)) {
call->params[(*param_idx)++] = nir_src_for_ssa(value->def);
} else {
unsigned elems = glsl_get_length(value->type);
@ -150,7 +153,17 @@ vtn_ssa_value_load_function_param(struct vtn_builder *b,
struct vtn_func_arg_info *info,
unsigned *param_idx)
{
if (glsl_type_is_vector_or_scalar(value->type)) {
if (glsl_type_is_cmat(value->type)) {
nir_variable *copy_var =
nir_local_variable_create(b->nb.impl, value->type, "cmat_param_by_value");
nir_def *param = nir_load_param(&b->nb, (*param_idx)++);
nir_deref_instr *copy = nir_build_deref_var(&b->nb, copy_var);
nir_cmat_copy(&b->nb, &copy->def, param);
value->is_variable = true;
value->var = copy_var;
} else if (glsl_type_is_vector_or_scalar(value->type)) {
/* if the parameter is passed by value, we need to create a local copy if it's a pointer */
if (info->by_value && type && type->base_type == vtn_base_type_pointer) {
struct vtn_type *pointee_type = type->pointed;