From 3eef0c02453b59ed6380338fde8edd3063dd4227 Mon Sep 17 00:00:00 2001
From: Dave Airlie
Date: Mon, 18 Aug 2025 11:32:40 +1000
Subject: [PATCH] radv: add support for cooperative matrix per element operations.

Reviewed-by: Georg Lehmann
Part-of:
---
 .../nir/radv_nir_lower_cooperative_matrix.c | 75 +++++++++++++++++++
 src/amd/vulkan/radv_physical_device.c | 1 +
 2 files changed, 76 insertions(+)

diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index 7dc242b3ce6..7998078e43d 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -983,6 +983,78 @@ lower_cmat_deref(nir_deref_instr *deref, struct hash_table *type_map, const lowe
    return false;
 }
 
+static bool /* Lowers one nir_cmat_call_op_per_element_op call; removes it and always returns true. */
+lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call, const lower_cmat_params *params)
+{
+   nir_def *src = radv_nir_load_cmat(b, params, call->params[3].ssa); /* params[3]: source matrix */
+   nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]); /* params[0]: destination matrix deref */
+   nir_function *fnptr = call->callee;
+   /* The callee is emitted once per used source component; its results are gathered into the destination. */
+   struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
+   /* Scratch variable handed to the callee as params[0] and read back after every call. */
+   nir_variable *elem_tmp = nir_local_variable_create(b->impl, glsl_get_cmat_element(dst_deref->type), "elemtmp");
+   nir_deref_instr *elem_deref = nir_build_deref_var(b, elem_tmp);
+   nir_def *local_idx = nir_load_subgroup_invocation(b);
+   nir_def *inner_idx = nir_iand_imm(b, local_idx, 15); /* invocation index modulo 16 */
+   unsigned length = radv_nir_cmat_length(desc, params);
+   unsigned mul = radv_nir_cmat_length_mul(desc, params);
+   unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params->wave_size : 16;
+   nir_def *base_row = radv_get_base_row(b, desc, params, local_idx);
+   nir_def *vars[16];
+   /* Components whose index is not a multiple of mul are never written by the loop below; make them undef. */
+   if (mul > 1) {
+      for (unsigned i = 0; i < length; ++i)
+         if (i % mul != 0)
+            vars[i] = nir_undef(b, 1, radv_nir_cmat_bits(desc));
+   }
+   /* One callee invocation per used component (index i * mul) of the loaded source vector. */
+   for (unsigned i = 0; i < length / mul; i++) {
+      nir_def *src_elem = nir_channel(b, src, i * mul);
+      nir_call_instr *new_call = nir_call_instr_create(b->shader, fnptr);
+      uint32_t row_iter;
+      /* GFX12 and later use the iteration index as the row offset; older levels use i * lanes_per_iter / 16. */
+      if (params->gfx_level >= GFX12) {
+         row_iter = i;
+      } else {
+         row_iter = i * lanes_per_iter / 16;
+      }
+
+      nir_def *row_val = nir_iadd_imm(b, base_row, row_iter);
+      nir_def *col_val = inner_idx;
+      /* For use A the two computed coordinates trade places. */
+      if (desc.use == GLSL_CMAT_USE_A)
+         SWAP(col_val, row_val);
+      /* Add the values supplied in params[1] and params[2] to the row and column respectively. */
+      row_val = nir_iadd(b, call->params[1].ssa, row_val);
+      col_val = nir_iadd(b, call->params[2].ssa, col_val);
+      /* Callee arguments 0-3: scratch deref, row, column, source element. */
+      new_call->params[0] = nir_src_for_ssa(&elem_deref->def);
+      new_call->params[1] = nir_src_for_ssa(row_val);
+      new_call->params[2] = nir_src_for_ssa(col_val);
+      new_call->params[3] = nir_src_for_ssa(src_elem);
+      /* Forward the remaining operands; a cooperative-matrix deref is replaced by its component i * mul. */
+      for (unsigned p = 4; p < call->num_params; p++) {
+         nir_deref_instr *deref = nir_src_as_deref(call->params[p]);
+         nir_def *def = call->params[p].ssa;
+         if (deref) {
+            if (glsl_type_is_cmat(deref->type)) { /* NOTE(review): loaded with the destination's length/bit size - confirm operand types always match */
+               def = nir_build_load_deref(b, radv_nir_cmat_length(desc, params), radv_nir_cmat_bits(desc), def);
+               def = nir_channel(b, def, i * mul);
+            }
+         }
+         new_call->params[p] = nir_src_for_ssa(def);
+      }
+      nir_builder_instr_insert(b, &new_call->instr);
+      vars[i * mul] = nir_build_load_deref(b, 1, radv_nir_cmat_bits(desc), &elem_deref->def, 0); /* read back what the callee left in the scratch variable */
+   }
+   /* Rebuild the result vector from the per-component values and store it to the destination. */
+   nir_def *mat = nir_vec(b, vars, length);
+   nir_store_deref(b, dst_deref, mat, nir_component_mask(src->num_components)); /* NOTE(review): mask sized from src, not length - confirm they are equal */
+   /* The original call has been replaced by the instructions emitted above. */
+   nir_instr_remove(&call->instr);
+   return true;
+}
+
 bool
 radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
 {
@@ -1084,6 +1156,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
          case
nir_cmat_call_op_reduce_2x2: progress |= lower_cmat_reduce_2x2_call(&b, call, &params); break; + case nir_cmat_call_op_per_element_op: + progress |= lower_cmat_per_element_op(&b, call, &params); + break; default: break; } diff --git a/src/amd/vulkan/radv_physical_device.c b/src/amd/vulkan/radv_physical_device.c index b1bb946fdf4..256ec797392 100644 --- a/src/amd/vulkan/radv_physical_device.c +++ b/src/amd/vulkan/radv_physical_device.c @@ -1403,6 +1403,7 @@ radv_physical_device_get_features(const struct radv_physical_device *pdev, struc .cooperativeMatrixConversions = true, .cooperativeMatrixFlexibleDimensions = true, .cooperativeMatrixReductions = true, + .cooperativeMatrixPerElementOperations = true, /* VK_KHR_video_encode_av1 */ .videoEncodeAV1 = true,