radv/nir/lower_cmat: use cmat_mul instead of duplicating hw details for type conversion

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34382>
This commit is contained in:
Georg Lehmann 2025-04-04 13:47:31 +02:00 committed by Marge Bot
parent 31a3430570
commit bbc9bc9d24

View file

@@ -445,25 +445,29 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
nir_get_nir_type_for_glsl_base_type(dst_element_type),
nir_rounding_mode_undef);
if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 16 && radv_nir_cmat_bits(dst_desc) == 32 &&
dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
unsigned dst_mul = radv_nir_cmat_length_mul(dst_desc, &params);
unsigned src_mul = radv_nir_cmat_length_mul(src_desc, &params);
if (src_mul > dst_mul) {
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i * 2 < src->num_components; ++i) {
components[i] = nir_channel(&b, src, i * 2);
unsigned scale = src_mul / dst_mul;
for (unsigned i = 0; i * scale < src->num_components; ++i) {
components[i] = nir_channel(&b, src, i * scale);
}
src = nir_vec(&b, components, src->num_components / 2);
src = nir_vec(&b, components, src->num_components / scale);
}
nir_def *ret = nir_build_alu1(&b, op, src);
if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 32 && radv_nir_cmat_bits(dst_desc) == 16 &&
dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
if (dst_mul > src_mul) {
nir_def *components[NIR_MAX_VEC_COMPONENTS];
unsigned scale = dst_mul / src_mul;
for (unsigned i = 0; i < ret->num_components; ++i) {
components[i * 2] = nir_channel(&b, ret, i);
components[i * 2 + 1] = nir_undef(&b, 1, 16);
components[i * scale] = nir_channel(&b, ret, i);
for (unsigned j = 1; j < scale; j++)
components[i * scale + j] = nir_undef(&b, 1, ret->bit_size);
}
ret = nir_vec(&b, components, ret->num_components * 2);
ret = nir_vec(&b, components, ret->num_components * scale);
}
nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));