diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index 2fb585501c8..5e0f66a7c14 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -605,10 +605,14 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev src = nir_vec(&b, components, src->num_components / scale); } - src = convert_use(&b, src, src_use, dst_use, ¶ms); + if (radv_nir_cmat_bits(src_desc) <= radv_nir_cmat_bits(dst_desc)) + src = convert_use(&b, src, src_use, dst_use, ¶ms); nir_def *ret = convert_base_type(&b, src, src_element_type, dst_element_type, sat); + if (radv_nir_cmat_bits(src_desc) > radv_nir_cmat_bits(dst_desc)) + ret = convert_use(&b, ret, src_use, dst_use, ¶ms); + if (dst_mul > src_mul) { nir_def *components[NIR_MAX_VEC_COMPONENTS]; unsigned scale = dst_mul / src_mul;