nak: rework scale argument of compute_mat and rename it

"scale" was a bad name as it meant nothing others could comprehend.
However that value corresponded to the tile size of the Matrix layouts.
For int8 we have a tile size of 4, for 16 and 32 bit values we have a size
of 2.

In the future the same value needs to be 32 for booleans, 8 for int4,
1 for fp64 and tf32, and 4 for all eXmY types.

The "scale = 1 << scale;" assignment can simply be removed because it was
4 for 2 and 2 for 1 simply being the expected value after this change.
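
To make the mapping concrete, here is a minimal sketch of the intended
per-type values as a lookup; the enum and helper are hypothetical (they
do not exist in the tree), and the values are the ones listed above:

/* Hypothetical sketch of group_size per element type, using the values
 * listed above; the enum and helper are illustrative only. */
enum cmat_elem_type {
   CMAT_BOOL, CMAT_INT4, CMAT_INT8, CMAT_FP8,
   CMAT_F16_BF16, CMAT_F32_I32, CMAT_TF32, CMAT_FP64,
};

static unsigned
cmat_group_size(enum cmat_elem_type t)
{
   switch (t) {
   case CMAT_BOOL:     return 32; /* 1-bit values */
   case CMAT_INT4:     return 8;
   case CMAT_INT8:     return 4;
   case CMAT_FP8:      return 4;  /* eXmY formats */
   case CMAT_F16_BF16: return 2;
   case CMAT_F32_I32:  return 2;
   case CMAT_TF32:     return 1;
   case CMAT_FP64:     return 1;
   }
   return 0;
}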

Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36391>
Karol Herbst, 2025-07-26 14:28:39 +02:00
parent d0f4b535fe
commit b9f060438f


@@ -232,36 +232,38 @@ remap_matrix_type(struct hash_table *mapping, const struct glsl_type *orig)
    return new_type;
 }
 
-/* Computes the index in a linear matrix buffer a thread needs to load from in
+/**
+ * Computes the index in a linear matrix buffer a thread needs to load from in
  * order to execute an MMA on the Matrix.
  *
  * This is a generalized formula based on the Matrix layout descriptions from
  * the CUDA PTX instruction set documentation:
  * https://docs.nvidia.com/cuda/archive/12.8.1/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
+ *
+ * \param group_size Size of the value groups the layout tiles around.
  */
 static void
 compute_mat(struct nir_builder *b, nir_def *lane_id,
             unsigned idx, nir_def **col, nir_def **row,
             struct glsl_cmat_description desc,
-            unsigned scale)
+            unsigned group_size)
 {
-   assert(idx < 8 * scale);
+   assert(idx < 4 * group_size);
 
    nir_def *quad_id = nir_ushr_imm(b, lane_id, 2);
    nir_def *thread_id_in_quad = nir_iand_imm(b, lane_id, 0x3);
 
-   unsigned row_bound = (desc.use == GLSL_CMAT_USE_B ? 4 : 2) * scale;
-   unsigned col_bound = (desc.use == GLSL_CMAT_USE_B ? 2 : 4) * scale;
-
-   scale = 1 << scale;
+   unsigned row_bound = (desc.use == GLSL_CMAT_USE_B ? 2 : 1) * group_size;
+   unsigned col_bound = (desc.use == GLSL_CMAT_USE_B ? 1 : 2) * group_size;
 
    *row = quad_id;
    if (idx & row_bound)
      *row = nir_iadd_imm(b, *row, 8);
 
-   *col = nir_iadd_imm(b, nir_imul_imm(b, thread_id_in_quad, scale),
-                       idx & (scale - 1));
+   *col = nir_iadd_imm(b, nir_imul_imm(b, thread_id_in_quad, group_size),
+                       idx & (group_size - 1));
 
    if (idx & col_bound)
-      *col = nir_iadd_imm(b, *col, scale * 4);
+      *col = nir_iadd_imm(b, *col, group_size * 4);
 }
 
 static void
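
For illustration, here is a standalone sketch (not part of the patch)
that mirrors the reworked index math with plain integers, for a use-A
16x16 matrix with 16-bit elements, i.e. group_size = 2:

/* Standalone illustration of the reworked compute_mat() arithmetic,
 * with NIR builder calls replaced by plain integer operations. */
#include <stdio.h>

int main(void)
{
   const unsigned group_size = 2;
   const unsigned row_bound = 1 * group_size; /* non-B use */
   const unsigned col_bound = 2 * group_size;

   for (unsigned lane_id = 0; lane_id < 8; lane_id++) {
      unsigned quad_id = lane_id >> 2;      /* nir_ushr_imm(lane_id, 2) */
      unsigned tid_in_quad = lane_id & 0x3; /* nir_iand_imm(lane_id, 3) */
      for (unsigned idx = 0; idx < 4 * group_size; idx++) {
         unsigned row = quad_id + ((idx & row_bound) ? 8 : 0);
         unsigned col = tid_in_quad * group_size +
                        (idx & (group_size - 1)) +
                        ((idx & col_bound) ? group_size * 4 : 0);
         printf("lane %2u idx %u -> row %2u col %2u\n",
                lane_id, idx, row, col);
      }
   }
   return 0;
}

For lane 0 this prints the (row, col) pairs (0,0), (0,1), (8,0), (8,1),
(0,8), (0,9), (8,8), (8,9), matching the fragment layout pattern in the
PTX documentation linked above.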
@@ -269,7 +271,7 @@ compute_mat_16x32_int8(struct nir_builder *b, nir_def *lane_id,
                        unsigned idx, nir_def **col, nir_def **row,
                        struct glsl_cmat_description desc)
 {
-   compute_mat(b, lane_id, idx, col, row, desc, 2);
+   compute_mat(b, lane_id, idx, col, row, desc, 4);
 }
 
 static void
@@ -277,7 +279,7 @@ compute_mat_16x16(struct nir_builder *b, nir_def *lane_id,
                   unsigned idx, nir_def **col, nir_def **row,
                   struct glsl_cmat_description desc)
 {
-   compute_mat(b, lane_id, idx, col, row, desc, 1);
+   compute_mat(b, lane_id, idx, col, row, desc, 2);
 }
 
 static void
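
Under the values listed in the commit message, a future caller for an
int4 layout would now simply pass its group size of 8 directly. A
hypothetical sketch (no such wrapper exists in the tree, and the matrix
dimensions are an assumption):

/* Hypothetical future wrapper, following the pattern of the existing
 * compute_mat_* helpers; group_size 8 per the commit message. */
static void
compute_mat_16x64_int4(struct nir_builder *b, nir_def *lane_id,
                       unsigned idx, nir_def **col, nir_def **row,
                       struct glsl_cmat_description desc)
{
   compute_mat(b, lane_id, idx, col, row, desc, 8);
}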