From a57753dca21b456c2f55dd0047d293aef096edf4 Mon Sep 17 00:00:00 2001 From: Dave Airlie Date: Mon, 18 Aug 2025 11:30:41 +1000 Subject: [PATCH] nir: add coopmat per element operations. Cooperative matrix as per-element calls that are var args from a spir-v. These uses the new call op enum. Reviewed-by: Georg Lehmann Part-of: --- src/compiler/nir/nir.c | 2 + src/compiler/nir/nir.h | 5 ++ .../nir/nir_lower_cooperative_matrix.c | 47 +++++++++++++++++++ src/compiler/nir/nir_print.c | 2 + 4 files changed, 56 insertions(+) diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c index c62f63c1421..c681dc1c61e 100644 --- a/src/compiler/nir/nir.c +++ b/src/compiler/nir/nir.c @@ -913,6 +913,8 @@ int nir_cmat_call_op_params(nir_cmat_call_op op, nir_function *callee) { switch (op) { + case nir_cmat_call_op_per_element_op: + return callee->num_params; case nir_cmat_call_op_reduce: return 2; case nir_cmat_call_op_reduce_finish: diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 6013fe77b38..95754d64e78 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -1859,6 +1859,11 @@ typedef enum { * reduce 2x2 dst, src0, src1, src2, src3. */ nir_cmat_call_op_reduce_2x2, + /* + * Cooperative matrix per-element operation call + * per-element dst, row offset, col offset, src + */ + nir_cmat_call_op_per_element_op, } nir_cmat_call_op; typedef struct nir_cmat_call_instr { diff --git a/src/compiler/nir/nir_lower_cooperative_matrix.c b/src/compiler/nir/nir_lower_cooperative_matrix.c index 990bcff60fe..99d9eabefa5 100644 --- a/src/compiler/nir/nir_lower_cooperative_matrix.c +++ b/src/compiler/nir/nir_lower_cooperative_matrix.c @@ -724,6 +724,50 @@ split_cmat_load_store(nir_builder *b, return true; } +static bool +split_cmat_call_per_element_op(nir_builder *b, + nir_cmat_call_instr *call, + struct split_info *info) +{ + nir_instr *instr = &call->instr; + struct split_mat *dst_split = find_call_split(info->split_mats, call, 0); + struct split_mat *src_split = find_call_split(info->split_mats, call, 3); + if (!dst_split) + return false; + + assert(src_split); + int splits = dst_split->num_col_splits * dst_split->num_row_splits; + if (splits <= 1) + return false; + + for (unsigned r = 0; r < dst_split->num_row_splits; r++) { + for (unsigned c = 0; c < dst_split->num_col_splits; c++) { + int idx = r * dst_split->num_col_splits + c; + nir_deref_instr *dst_deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[idx]); + nir_deref_instr *src_deref = recreate_derefs(b, &call->params[3], src_split->split_vars[idx]); + struct glsl_cmat_description cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type); + nir_cmat_call_instr *new_call = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_per_element_op, call->callee); + new_call->params[0] = nir_src_for_ssa(&dst_deref->def); + new_call->params[1] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.rows * r)); + new_call->params[2] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.cols * c)); + new_call->params[3] = nir_src_for_ssa(&src_deref->def); + + for (unsigned i = 4; i < call->num_params; i++) { + if (nir_src_as_deref(call->params[i])) { + struct split_mat *src1_split = find_call_split(info->split_mats, call, i); + nir_deref_instr *src1_deref = src1_split ? recreate_derefs(b, &call->params[i], src1_split->split_vars[idx]) : nir_src_as_deref(call->params[i]); + new_call->params[i] = src1_deref ? nir_src_for_ssa(&src1_deref->def) : call->params[i]; + } else + new_call->params[i] = call->params[i]; + } + b->cursor = nir_before_instr(instr); + nir_builder_instr_insert(b, &new_call->instr); + } + } + nir_instr_remove(instr); + return true; +} + static bool split_matrix_impl(nir_function_impl *impl, struct split_info *info) { @@ -787,6 +831,9 @@ split_matrix_impl(nir_function_impl *impl, struct split_info *info) case nir_cmat_call_op_reduce: progress |= split_cmat_call_reduce(&b, impl, cmat_call, info); break; + case nir_cmat_call_op_per_element_op: + progress |= split_cmat_call_per_element_op(&b, cmat_call, info); + break; default: break; } diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c index cdfbdecbbc4..da490c4508b 100644 --- a/src/compiler/nir/nir_print.c +++ b/src/compiler/nir/nir_print.c @@ -2077,6 +2077,8 @@ get_cmat_call_op_str(nir_cmat_call_op op) return "cmat_call_reduce_finish"; case nir_cmat_call_op_reduce_2x2: return "cmat_call_reduce_2x2"; + case nir_cmat_call_op_per_element_op: + return "cmat_call_per_element"; } UNREACHABLE("Unknown cmat call op"); }