diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index 7c4cfe5bc95..2e9ee078679 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -72,25 +72,25 @@ radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params if (params->gfx_level >= GFX12) { assert(desc.cols == 16 && desc.rows == 16); return 256 / params->wave_size; + } else if (desc.use != GLSL_CMAT_USE_ACCUMULATOR) { + return 16; } else { - return desc.use != GLSL_CMAT_USE_ACCUMULATOR - ? 16 - : (desc.cols * desc.rows / params->wave_size * 32 / radv_nir_cmat_bits(desc)); + return desc.cols * desc.rows / params->wave_size * (radv_nir_cmat_bits(desc) == 16 ? 2 : 1); } } static unsigned radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_params *params) { - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX12 || desc.use != GLSL_CMAT_USE_ACCUMULATOR) { return 1; } else { - /* For C matrices we have 1 VGPR per element even if the element type is - * < 32 bits. So with 8 fp16 elements we implement that with a f16vec16. + /* For GFX11 C matrices we have 1 VGPR per element even if the element type is + * 16bits. So with 8 fp16 elements we implement that with a f16vec16. * We then use the coefficient generated by this function to figure out * how many elements we really have. */ - return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / radv_nir_cmat_bits(desc)) : 1; + return radv_nir_cmat_bits(desc) == 16 ? 2 : 1; } }