From d06aff224388ffbaa281c84db8f1922ec0e8bbf9 Mon Sep 17 00:00:00 2001
From: Karol Herbst
Date: Tue, 18 Nov 2025 15:46:17 +0100
Subject: [PATCH] nak/cmat: use movm

Sadly I don't see an obvious way to use it for int8 matrices, so the
code is a bit of a mess right now.

It allows us to vectorize loads/stores more often, as we can simply
transpose row/col major matrices when needed.

The movm optimization is also only enabled for 16 bit types, even
though we _could_ do it for 32 bit. It's not clear yet whether using it
for 32 bit types is an overall advantage or not.

Reviewed-by: Mel Henning
Part-of:
---
 src/nouveau/compiler/nak_nir_lower_cmat.c | 112 ++++++++++++++++++----
 1 file changed, 95 insertions(+), 17 deletions(-)

diff --git a/src/nouveau/compiler/nak_nir_lower_cmat.c b/src/nouveau/compiler/nak_nir_lower_cmat.c
index b71a4301824..49315f23bdb 100644
--- a/src/nouveau/compiler/nak_nir_lower_cmat.c
+++ b/src/nouveau/compiler/nak_nir_lower_cmat.c
@@ -232,6 +232,66 @@ remap_matrix_type(struct hash_table *mapping, const struct glsl_type *orig)
    return new_type;
 }
 
+static bool
+uses_movm_for_bit_size(unsigned bit_size)
+{
+   return bit_size == 16;
+}
+
+/**
+ * Returns true when the matrix has to be transposed after loads or before stores.
+ */
+static bool
+transpose_on_load_store(struct glsl_cmat_description desc,
+                        enum glsl_matrix_layout layout)
+{
+   return
+      uses_movm_for_bit_size(glsl_base_type_get_bit_size(desc.element_type)) &&
+      ((desc.use == GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) ||
+       (desc.use != GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR));
+}
+
+static nir_def *
+transpose_matrix(nir_builder *b, nir_def *value)
+{
+   unsigned vec_size = value->num_components;
+   unsigned bit_size = value->bit_size;
+
+   switch (bit_size) {
+   case 32: {
+      assert(vec_size == 2);
+
+      nir_def *raw = nir_unpack_64_4x16(b, nir_pack_64_2x32(b, value));
+      nir_def *lo = nir_vec2(b,
+                             nir_channel(b, raw, 0),
+                             nir_channel(b, raw, 2)
+      );
+      nir_def *hi = nir_vec2(b,
+                             nir_channel(b, raw, 1),
+                             nir_channel(b, raw, 3)
+      );
+
+      lo = nir_cmat_mov_transpose_nv(b, lo);
+      hi = nir_cmat_mov_transpose_nv(b, hi);
+
+      value = nir_vec2(b,
+         nir_pack_32_2x16(b, nir_vec2(b, nir_channel(b, lo, 0), nir_channel(b, hi, 0))),
+         nir_pack_32_2x16(b, nir_vec2(b, nir_channel(b, lo, 1), nir_channel(b, hi, 1)))
+      );
+      break;
+   }
+   case 16:
+      assert(vec_size == 2);
+      value = nir_cmat_mov_transpose_nv(b, value);
+      break;
+   default:
+      assert(!"unsupported bit_size for transpose");
+      break;
+   }
+
+   return value;
+}
+
 /**
  * Computes the index in a linear matrix buffer a thread needs to load from in
  * order to execute an MMA on the Matrix.
@@ -245,7 +305,7 @@ remap_matrix_type(struct hash_table *mapping, const struct glsl_type *orig)
 static void
 compute_mat(struct nir_builder *b, nir_def *lane_id,
             unsigned idx, nir_def **col, nir_def **row,
-            struct glsl_cmat_description desc,
+            bool alternate_tiling_order,
             unsigned group_size)
 {
    assert(idx < 4 * group_size);
@@ -253,8 +313,8 @@ compute_mat(struct nir_builder *b, nir_def *lane_id,
    nir_def *quad_id = nir_ushr_imm(b, lane_id, 2);
    nir_def *thread_id_in_quad = nir_iand_imm(b, lane_id, 0x3);
 
-   unsigned row_bound = (desc.use == GLSL_CMAT_USE_B ? 2 : 1) * group_size;
-   unsigned col_bound = (desc.use == GLSL_CMAT_USE_B ? 1 : 2) * group_size;
+   unsigned row_bound = (alternate_tiling_order ? 2 : 1) * group_size;
+   unsigned col_bound = (alternate_tiling_order ? 1 : 2) * group_size;
 
    *row = quad_id;
    if (idx & row_bound)
@@ -269,17 +329,17 @@ compute_mat(struct nir_builder *b, nir_def *lane_id,
 static void
 compute_mat_16x32_int8(struct nir_builder *b, nir_def *lane_id,
                        unsigned idx, nir_def **col, nir_def **row,
-                       struct glsl_cmat_description desc)
+                       bool alternate_tiling_order)
 {
-   compute_mat(b, lane_id, idx, col, row, desc, 4);
+   compute_mat(b, lane_id, idx, col, row, alternate_tiling_order, 4);
 }
 
 static void
 compute_mat_16x16(struct nir_builder *b, nir_def *lane_id,
                   unsigned idx, nir_def **col, nir_def **row,
-                  struct glsl_cmat_description desc)
+                  bool alternate_tiling_order)
 {
-   compute_mat(b, lane_id, idx, col, row, desc, 2);
+   compute_mat(b, lane_id, idx, col, row, alternate_tiling_order, 2);
 }
 
 static void
@@ -288,24 +348,32 @@ compute_matrix_offsets(struct nir_builder *b, struct glsl_cmat_description desc,
                        unsigned idx, nir_def **col_offset, nir_def **row_offset)
 {
    enum nak_matrix_type_layout cmat_type = determine_matrix_type(desc);
+   unsigned bit_size = glsl_base_type_bit_size(desc.element_type);
+   bool uses_movm = uses_movm_for_bit_size(bit_size);
+   bool alternate_tiling_order =
+      (uses_movm && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
+      (!uses_movm && desc.use == GLSL_CMAT_USE_B);
+
    switch (cmat_type) {
    case NAK_MAT_16x32_INT8:
-      compute_mat_16x32_int8(b, lane_id, idx, col_offset, row_offset, desc);
+      compute_mat_16x32_int8(b, lane_id, idx, col_offset, row_offset, alternate_tiling_order);
       break;
 
    case NAK_MAT_16X16:
-      compute_mat_16x16(b, lane_id, idx, col_offset, row_offset, desc);
+      compute_mat_16x16(b, lane_id, idx, col_offset, row_offset, alternate_tiling_order);
       break;
    }
 
    /* The layout calculation code relies on col and row being swapped for B
     * row-major and non B col-major matrices.
     */
-   if ((desc.use == GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
-       (desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR)) {
-      nir_def *tmp = *col_offset;
-      *col_offset = *row_offset;
-      *row_offset = tmp;
+   if (!uses_movm) {
+      if ((desc.use == GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
+          (desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR)) {
+         nir_def *tmp = *col_offset;
+         *col_offset = *row_offset;
+         *row_offset = tmp;
+      }
    }
 }
 
@@ -678,11 +746,16 @@ try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
 static int
 load_store_get_vec_size(const struct glsl_cmat_description desc, enum glsl_matrix_layout layout)
 {
-   if ((desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
-       (desc.use == GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR))
+   unsigned bit_size = glsl_base_type_bit_size(desc.element_type);
+   bool uses_movm = uses_movm_for_bit_size(bit_size);
+   bool needs_transpose =
+      (desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
+      (desc.use == GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR);
+
+   if (needs_transpose && !uses_movm)
       return 1;
 
-   switch (glsl_base_type_bit_size(desc.element_type)) {
+   switch (bit_size) {
    case 16:
    case 32:
       return 2;
@@ -783,6 +856,9 @@ lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
                                          0, vec_size * glsl_base_type_bit_size(desc.element_type) / 8, 0);
 
       nir_def *value = nir_load_deref(b, iter_deref);
+      if (transpose_on_load_store(desc, layout))
+         value = transpose_matrix(b, value);
+
       for (int c = 0; c < vec_size; c++)
          vars[idx + c] = nir_channel(b, value, c);
    }
@@ -856,6 +932,8 @@ lower_cmat_instr(nir_builder *b,
                                          &iter_deref->def, modes, vec_type, 0,
                                          vec_size * glsl_base_type_bit_size(desc.element_type) / 8, 0);
       nir_def *value = nir_vec(b, &vars[idx], vec_size);
+      if (transpose_on_load_store(desc, layout))
+         value = transpose_matrix(b, value);
       nir_store_deref(b, iter_deref, value, -1);
    }
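
As an illustration (not part of the patch itself), the decision the patch introduces can be sketched as a small standalone C program, assuming my reading of it is right: movm is only used for 16 bit element types, and a matrix is transposed on load/store whenever its memory layout does not match the layout the MMA wants for that matrix use. The names below (enum layout, enum cmat_use, needs_movm_transpose, ...) are simplified stand-ins, not Mesa's real glsl_matrix_layout / glsl_cmat_description API.

/* Illustrative sketch only; simplified stand-ins for the Mesa types used in
 * the patch (glsl_matrix_layout, glsl_cmat_description, GLSL_CMAT_USE_*).
 */
#include <stdbool.h>
#include <stdio.h>

enum layout { ROW_MAJOR, COLUMN_MAJOR };
enum cmat_use { USE_A, USE_B, USE_ACCUMULATOR };

/* The patch only takes the movm path for 16-bit element types. */
static bool uses_movm(unsigned bit_size)
{
   return bit_size == 16;
}

/* Mirrors the idea behind transpose_on_load_store(): B matrices are consumed
 * column-major by the MMA while A/accumulator matrices are consumed
 * row-major, so a 16-bit matrix loaded or stored with the opposite memory
 * layout gets a movm transpose instead of falling back to scalar
 * loads/stores.
 */
static bool needs_movm_transpose(enum cmat_use use, enum layout l, unsigned bits)
{
   return uses_movm(bits) &&
          ((use == USE_B && l != COLUMN_MAJOR) ||
           (use != USE_B && l == COLUMN_MAJOR));
}

int main(void)
{
   /* 16-bit B matrix in a row-major buffer: transposed, the load itself can
    * stay vectorized. */
   printf("%d\n", needs_movm_transpose(USE_B, ROW_MAJOR, 16));   /* prints 1 */
   /* 8-bit matrices never take the movm path, as the commit message notes. */
   printf("%d\n", needs_movm_transpose(USE_B, ROW_MAJOR, 8));    /* prints 0 */
   return 0;
}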