From 6d2190300aa0147919ca08b02ace568d539c11eb Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Fri, 4 Apr 2025 13:53:54 +0200 Subject: [PATCH] radv/nir/lower_cmat: tightly pack 8bit gfx11 acc matrix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Invalid for now, but used by vkd3d-proton, where the use case is to convert a result matrix to lower precision, followed by a store. For 16bit accumulation matrices, GFX11 only uses 16bits per 32bit register. RADV's coop matrix code pads the unused space with undefs and uses a vector with twice as many elements as the matrix length. Extending that to 8bit by leaving 24 bits unused is unnecessary for these matrices, as there is no hw unit that requires it. And in wave32, it would also result in vectors larger than NIR's limit. So tightly pack 8bit matrices without any undef padding. Reviewed-by: Timur Kristóf Part-of: --- .../vulkan/nir/radv_nir_lower_cooperative_matrix.c | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index 7c4cfe5bc95..2e9ee078679 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -72,25 +72,25 @@ radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params if (params->gfx_level >= GFX12) { assert(desc.cols == 16 && desc.rows == 16); return 256 / params->wave_size; + } else if (desc.use != GLSL_CMAT_USE_ACCUMULATOR) { + return 16; } else { - return desc.use != GLSL_CMAT_USE_ACCUMULATOR - ? 16 - : (desc.cols * desc.rows / params->wave_size * 32 / radv_nir_cmat_bits(desc)); + return desc.cols * desc.rows / params->wave_size * (radv_nir_cmat_bits(desc) == 16 ? 
2 : 1); } } static unsigned radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_params *params) { - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX12 || desc.use != GLSL_CMAT_USE_ACCUMULATOR) { return 1; } else { - /* For C matrices we have 1 VGPR per element even if the element type is - * < 32 bits. So with 8 fp16 elements we implement that with a f16vec16. + /* For GFX11 C matrices we have 1 VGPR per element even if the element type is + * 16bits. So with 8 fp16 elements we implement that with a f16vec16. * We then use the coefficient generated by this function to figure out * how many elements we really have. */ - return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / radv_nir_cmat_bits(desc)) : 1; + return radv_nir_cmat_bits(desc) == 16 ? 2 : 1; } }