nak/cmat: use movm

Sadly I don't see an obvious way to use movm for int8 matrices, so the
code is a bit of a mess right now.

It allows us to vectorize loads/stores more often, as we can simply
transpose row-/col-major matrices when needed.

The movm optimization is also only enabled for 16-bit types, even though
we _could_ do it for 32-bit. It's not clear yet whether using it for
32-bit types is an overall advantage.
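
Roughly, the new rule for 16-bit matrices is: keep the vectorized
loads/stores even when the memory layout doesn't match the layout the
fragment is kept in, and fix it up afterwards with a movm transpose in
registers; only the non-movm path still has to fall back to scalar
accesses. Below is a minimal host-side sketch of that decision, using
simplified stand-in enums and a hypothetical layout_mismatch() helper
rather than the real GLSL/NIR types, and assuming only row- and
column-major layouts:

#include <stdbool.h>
#include <stdio.h>

enum cmat_use { USE_A, USE_B, USE_ACCUMULATOR };
enum cmat_layout { ROW_MAJOR, COL_MAJOR };

/* movm is only used for 16-bit element types for now. */
static bool uses_movm(unsigned bit_size)
{
   return bit_size == 16;
}

/* True when the memory layout doesn't match the layout the fragment is
 * kept in: B fragments match column-major memory, A/accumulator match
 * row-major. */
static bool layout_mismatch(enum cmat_use use, enum cmat_layout layout)
{
   return use == USE_B ? layout != COL_MAJOR : layout == COL_MAJOR;
}

int main(void)
{
   /* A 16-bit use-A matrix stored column-major: previously this forced
    * scalar (vec_size == 1) loads, now it stays a vec2 load followed by
    * a movm transpose in registers. */
   unsigned bit_size = 16;
   enum cmat_use use = USE_A;
   enum cmat_layout layout = COL_MAJOR;

   unsigned vec_size =
      layout_mismatch(use, layout) && !uses_movm(bit_size) ? 1 : 2;
   bool transpose_in_regs =
      layout_mismatch(use, layout) && uses_movm(bit_size);

   printf("vec_size=%u, transpose in registers=%d\n",
          vec_size, transpose_in_regs);
   return 0;
}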

Reviewed-by: Mel Henning <mhenning@darkrefraction.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37998>
Karol Herbst 2025-11-18 15:46:17 +01:00 committed by Marge Bot
parent 626c6b35f0
commit d06aff2243


@@ -232,6 +232,66 @@ remap_matrix_type(struct hash_table *mapping, const struct glsl_type *orig)
return new_type;
}
static bool
uses_movm_for_bit_size(unsigned bit_size)
{
return bit_size == 16;
}
/**
* Returns true when the matrix has to be transposed after loads or before stores
*/
static bool
transpose_on_load_store(struct glsl_cmat_description desc,
enum glsl_matrix_layout layout)
{
return
uses_movm_for_bit_size(glsl_base_type_get_bit_size(desc.element_type)) &&
((desc.use == GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) ||
(desc.use != GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR));
}
static nir_def *
transpose_matrix(nir_builder *b, nir_def *value)
{
unsigned vec_size = value->num_components;
unsigned bit_size = value->bit_size;
switch (bit_size) {
case 32: {
assert(vec_size == 2);
nir_def *raw = nir_unpack_64_4x16(b, nir_pack_64_2x32(b, value));
nir_def *lo = nir_vec2(b,
nir_channel(b, raw, 0),
nir_channel(b, raw, 2)
);
nir_def *hi = nir_vec2(b,
nir_channel(b, raw, 1),
nir_channel(b, raw, 3)
);
lo = nir_cmat_mov_transpose_nv(b, lo);
hi = nir_cmat_mov_transpose_nv(b, hi);
value = nir_vec2(b,
nir_pack_32_2x16(b, nir_vec2(b, nir_channel(b, lo, 0), nir_channel(b, hi, 0))),
nir_pack_32_2x16(b, nir_vec2(b, nir_channel(b, lo, 1), nir_channel(b, hi, 1)))
);
break;
}
case 16:
assert(vec_size == 2);
value = nir_cmat_mov_transpose_nv(b, value);
break;
default:
assert(!"unsupported bit_size for transpose");
break;
}
return value;
}
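
The 32-bit case above never fires yet (uses_movm_for_bit_size() only
accepts 16), but the data movement it performs is simply: split each
32-bit lane value into its low and high 16-bit halves, transpose the two
16-bit planes independently with movm, and re-interleave the halves
afterwards. A small host-side model of just the pack/unpack shuffle
follows; the cross-lane movm step itself cannot be shown in scalar code,
and main() plus the constants are made up for illustration:

#include <stdint.h>
#include <stdio.h>

int main(void)
{
   /* One lane's vec2 of 32-bit elements. */
   uint32_t v0 = 0xAAAA1111u, v1 = 0xBBBB2222u;

   /* nir_pack_64_2x32 followed by nir_unpack_64_4x16 yields the 16-bit
    * channels {v0.lo, v0.hi, v1.lo, v1.hi}. */
   uint16_t raw[4] = { (uint16_t)v0, (uint16_t)(v0 >> 16),
                       (uint16_t)v1, (uint16_t)(v1 >> 16) };

   /* lo plane = channels 0 and 2, hi plane = channels 1 and 3; each
    * plane is what nir_cmat_mov_transpose_nv would transpose across the
    * warp. */
   uint16_t lo[2] = { raw[0], raw[2] };
   uint16_t hi[2] = { raw[1], raw[3] };

   /* Re-interleave: nir_pack_32_2x16 puts the first component in the
    * low 16 bits, so each output element is lo | (hi << 16) again. */
   uint32_t out0 = (uint32_t)lo[0] | ((uint32_t)hi[0] << 16);
   uint32_t out1 = (uint32_t)lo[1] | ((uint32_t)hi[1] << 16);

   /* Without the movm step this round-trips to v0, v1 unchanged. */
   printf("%08x %08x\n", out0, out1);
   return 0;
}
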
/**
* Computes the index in a linear matrix buffer that a thread needs to load from
* in order to execute an MMA on the matrix.
@@ -245,7 +305,7 @@ remap_matrix_type(struct hash_table *mapping, const struct glsl_type *orig)
static void
compute_mat(struct nir_builder *b, nir_def *lane_id,
unsigned idx, nir_def **col, nir_def **row,
struct glsl_cmat_description desc,
bool alternate_tiling_order,
unsigned group_size)
{
assert(idx < 4 * group_size);
@@ -253,8 +313,8 @@ compute_mat(struct nir_builder *b, nir_def *lane_id,
nir_def *quad_id = nir_ushr_imm(b, lane_id, 2);
nir_def *thread_id_in_quad = nir_iand_imm(b, lane_id, 0x3);
unsigned row_bound = (desc.use == GLSL_CMAT_USE_B ? 2 : 1) * group_size;
unsigned col_bound = (desc.use == GLSL_CMAT_USE_B ? 1 : 2) * group_size;
unsigned row_bound = (alternate_tiling_order ? 2 : 1) * group_size;
unsigned col_bound = (alternate_tiling_order ? 1 : 2) * group_size;
*row = quad_id;
if (idx & row_bound)
@@ -269,17 +329,17 @@ compute_mat(struct nir_builder *b, nir_def *lane_id,
static void
compute_mat_16x32_int8(struct nir_builder *b, nir_def *lane_id,
unsigned idx, nir_def **col, nir_def **row,
struct glsl_cmat_description desc)
bool alternate_tiling_order)
{
compute_mat(b, lane_id, idx, col, row, desc, 4);
compute_mat(b, lane_id, idx, col, row, alternate_tiling_order, 4);
}
static void
compute_mat_16x16(struct nir_builder *b, nir_def *lane_id,
unsigned idx, nir_def **col, nir_def **row,
struct glsl_cmat_description desc)
bool alternate_tiling_order)
{
compute_mat(b, lane_id, idx, col, row, desc, 2);
compute_mat(b, lane_id, idx, col, row, alternate_tiling_order, 2);
}
static void
@@ -288,19 +348,26 @@ compute_matrix_offsets(struct nir_builder *b, struct glsl_cmat_description desc,
unsigned idx, nir_def **col_offset, nir_def **row_offset)
{
enum nak_matrix_type_layout cmat_type = determine_matrix_type(desc);
unsigned bit_size = glsl_base_type_bit_size(desc.element_type);
bool uses_movm = uses_movm_for_bit_size(bit_size);
bool alternate_tiling_order =
(uses_movm && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
(!uses_movm && desc.use == GLSL_CMAT_USE_B);
switch (cmat_type) {
case NAK_MAT_16x32_INT8:
compute_mat_16x32_int8(b, lane_id, idx, col_offset, row_offset, desc);
compute_mat_16x32_int8(b, lane_id, idx, col_offset, row_offset, alternate_tiling_order);
break;
case NAK_MAT_16X16:
compute_mat_16x16(b, lane_id, idx, col_offset, row_offset, desc);
compute_mat_16x16(b, lane_id, idx, col_offset, row_offset, alternate_tiling_order);
break;
}
/* The layout calculation code relies on col and row being swapped for B
* row-major and non-B col-major matrices.
*/
if (!uses_movm) {
if ((desc.use == GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
(desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR)) {
nir_def *tmp = *col_offset;
@@ -308,6 +375,7 @@ compute_matrix_offsets(struct nir_builder *b, struct glsl_cmat_description desc,
*row_offset = tmp;
}
}
}
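
For the movm path, the tiling order used by the offset computation
follows the memory layout rather than desc.use, and the col/row swap
above is only applied on the legacy (non-movm) path. A hedged model of
just those two decisions for a few cases; the enums are simplified
stand-ins (not the real GLSL types) and only row- and column-major
layouts are considered:

#include <stdbool.h>
#include <stdio.h>

enum cmat_use { USE_A, USE_B, USE_ACCUMULATOR };
enum cmat_layout { ROW_MAJOR, COL_MAJOR };

int main(void)
{
   static const struct {
      unsigned bits;
      enum cmat_use use;
      enum cmat_layout layout;
   } cases[] = {
      { 16, USE_A, ROW_MAJOR }, { 16, USE_A, COL_MAJOR },
      { 16, USE_B, ROW_MAJOR }, {  8, USE_B, ROW_MAJOR },
   };

   for (unsigned i = 0; i < 4; i++) {
      bool movm = cases[i].bits == 16;

      /* movm path: tiling follows the memory layout;
       * legacy path: tiling follows use == B. */
      bool alternate_tiling_order =
         (movm && cases[i].layout != ROW_MAJOR) ||
         (!movm && cases[i].use == USE_B);

      /* Legacy path only: swap col/row for B row-major and
       * non-B col-major matrices. */
      bool swap_col_row = !movm &&
         ((cases[i].use == USE_B && cases[i].layout == ROW_MAJOR) ||
          (cases[i].use != USE_B && cases[i].layout != ROW_MAJOR));

      printf("bits=%u use=%d layout=%d -> alt_tiling=%d swap=%d\n",
             cases[i].bits, (int)cases[i].use, (int)cases[i].layout,
             alternate_tiling_order, swap_col_row);
   }
   return 0;
}
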
/* Returns the hw native Matrix muladd operation */
static enum nak_cmat_type
@@ -678,11 +746,16 @@ try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
static int load_store_get_vec_size(const struct glsl_cmat_description desc,
enum glsl_matrix_layout layout)
{
if ((desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
(desc.use == GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR))
unsigned bit_size = glsl_base_type_bit_size(desc.element_type);
bool uses_movm = uses_movm_for_bit_size(bit_size);
bool needs_transpose =
(desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
(desc.use == GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR);
if (needs_transpose && !uses_movm)
return 1;
switch (glsl_base_type_bit_size(desc.element_type)) {
switch (bit_size) {
case 16:
case 32:
return 2;
@@ -783,6 +856,9 @@ lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
0, vec_size * glsl_base_type_bit_size(desc.element_type) / 8, 0);
nir_def *value = nir_load_deref(b, iter_deref);
if (transpose_on_load_store(desc, layout))
value = transpose_matrix(b, value);
for (int c = 0; c < vec_size; c++)
vars[idx + c] = nir_channel(b, value, c);
}
@@ -856,6 +932,8 @@ lower_cmat_instr(nir_builder *b,
&iter_deref->def, modes, vec_type,
0, vec_size * glsl_base_type_bit_size(desc.element_type) / 8, 0);
nir_def *value = nir_vec(b, &vars[idx], vec_size);
if (transpose_on_load_store(desc, layout))
value = transpose_matrix(b, value);
nir_store_deref(b, iter_deref, value, -1);
}