From 1ba49c3594864c8faa41644d0359ccbdb4ff8e13 Mon Sep 17 00:00:00 2001 From: Dave Airlie Date: Mon, 18 Aug 2025 11:32:19 +1000 Subject: [PATCH] spirv: add initial support for cooperative matrix per-element ops Reviewed-by: Georg Lehmann Part-of: --- src/compiler/spirv/spirv_to_nir.c | 2 ++ src/compiler/spirv/vtn_cmat.c | 35 +++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 15e8ccb6558..be4098001d8 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -74,6 +74,7 @@ static const struct spirv_capabilities implemented_capabilities = { .CooperativeMatrixKHR = true, .CooperativeMatrixConversionsNV = true, .CooperativeMatrixReductionsNV = true, + .CooperativeMatrixPerElementOperationsNV = true, .CoreBuiltinsARM = true, .CullDistance = true, .DemoteToHelperInvocation = true, @@ -7011,6 +7012,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpCooperativeMatrixConvertNV: case SpvOpCooperativeMatrixTransposeNV: case SpvOpCooperativeMatrixReduceNV: + case SpvOpCooperativeMatrixPerElementOpNV: vtn_handle_cooperative_instruction(b, opcode, w, count); break; diff --git a/src/compiler/spirv/vtn_cmat.c b/src/compiler/spirv/vtn_cmat.c index 1390269d5ff..1b4de345327 100644 --- a/src/compiler/spirv/vtn_cmat.c +++ b/src/compiler/spirv/vtn_cmat.c @@ -238,6 +238,41 @@ vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode, break; } + case SpvOpCooperativeMatrixPerElementOpNV: { + struct vtn_type *dst_type = vtn_get_type(b, w[1]); + nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]); + + struct vtn_function *per_element_fn = vtn_value(b, w[4], vtn_value_type_function)->func; + + per_element_fn->referenced = true; + per_element_fn->nir_func->cmat_call = true; + nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_per_element_nv"); + + nir_cmat_call_instr *call = nir_cmat_call_instr_create(b->nb.shader, nir_cmat_call_op_per_element_op, per_element_fn->nir_func); + call->params[0] = nir_src_for_ssa(&dst->def); + call->params[1] = nir_src_for_ssa(nir_imm_zero(&b->nb, 1, 32)); + call->params[2] = nir_src_for_ssa(nir_imm_zero(&b->nb, 1, 32)); + call->params[3] = nir_src_for_ssa(&src->def); + + for (unsigned i = 0; i < count - 5; i++) { + struct vtn_ssa_value *ssa = vtn_ssa_value(b, w[5 + i]); + nir_def *def; + nir_deref_instr *deref = NULL; + + if (ssa->is_variable) { + deref = nir_build_deref_var(&b->nb, ssa->var); + def = &deref->def; + } else if (glsl_type_is_vector_or_scalar(ssa->type)) { + def = ssa->def; + } else + def = ssa->elems[0]->def; + + call->params[4 + i] = nir_src_for_ssa(def); + } + nir_builder_instr_insert(&b->nb, &call->instr); + vtn_push_var_ssa(b, w[2], dst->var); + break; + } default: UNREACHABLE("Unexpected opcode for cooperative matrix instruction"); }