From bbc9bc9d245ce2394cc8b74f2dec1fdfbbaa67f5 Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Fri, 4 Apr 2025 13:47:31 +0200 Subject: [PATCH] radv/nir/lower_cmat: use cmat_mul instead of duplicating hw details for type conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewed-by: Timur Kristóf Part-of: --- .../nir/radv_nir_lower_cooperative_matrix.c | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index ad89b30e0b3..7c4cfe5bc95 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -445,25 +445,29 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev nir_get_nir_type_for_glsl_base_type(dst_element_type), nir_rounding_mode_undef); - if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 16 && radv_nir_cmat_bits(dst_desc) == 32 && - dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) { + unsigned dst_mul = radv_nir_cmat_length_mul(dst_desc, ¶ms); + unsigned src_mul = radv_nir_cmat_length_mul(src_desc, ¶ms); + + if (src_mul > dst_mul) { nir_def *components[NIR_MAX_VEC_COMPONENTS]; - for (unsigned i = 0; i * 2 < src->num_components; ++i) { - components[i] = nir_channel(&b, src, i * 2); + unsigned scale = src_mul / dst_mul; + for (unsigned i = 0; i * scale < src->num_components; ++i) { + components[i] = nir_channel(&b, src, i * scale); } - src = nir_vec(&b, components, src->num_components / 2); + src = nir_vec(&b, components, src->num_components / scale); } nir_def *ret = nir_build_alu1(&b, op, src); - if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 32 && radv_nir_cmat_bits(dst_desc) == 16 && - dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) { + if (dst_mul > src_mul) { nir_def *components[NIR_MAX_VEC_COMPONENTS]; + unsigned scale = dst_mul / src_mul; for (unsigned i = 0; i < ret->num_components; ++i) { - components[i * 2] = nir_channel(&b, ret, i); - components[i * 2 + 1] = nir_undef(&b, 1, 16); + components[i * scale] = nir_channel(&b, ret, i); + for (unsigned j = 1; j < scale; j++) + components[i * scale + j] = nir_undef(&b, 1, ret->bit_size); } - ret = nir_vec(&b, components, ret->num_components * 2); + ret = nir_vec(&b, components, ret->num_components * scale); } nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));