nak: use ldsm

Reviewed-by: Mary Guillemard <mary@mary.zone>
Acked-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36363>
This commit is contained in:
Karol Herbst 2025-07-14 12:30:16 +02:00
parent c38170452d
commit 26c1ded905

View file

@ -570,9 +570,116 @@ lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intr, nir_def *cmat,
return cmat;
}
static struct nir_def*
try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
{
assert(intr->intrinsic == nir_intrinsic_cmat_load);
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
const struct glsl_cmat_description desc = cmat_src_desc(intr->src[0]);
const unsigned length = get_cmat_length(desc);
nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
const unsigned ptr_bit_size = glsl_get_bit_size(deref->type);
const unsigned vec = glsl_get_vector_elements(deref->type);
nir_src stride = intr->src[2];
/* Even though LDSM operates on 16 bit types, the int8 matrix layout is
* compatible so that we can use LDSM on it as well. But we can't use it on
* the 32 bit types, because that actually uses a different data layout on a
* byte level.
*/
const unsigned bit_size = glsl_base_type_bit_size(desc.element_type);
if (!nir_src_is_const(stride)
|| !nir_deref_mode_is(deref, nir_var_mem_shared)
|| bit_size > 16)
return NULL;
/* The stride is in elements of the pointed to type, not necessarily the
* type of the referenced matrix
*/
unsigned stride_bytes = nir_src_as_uint(stride) * vec * ptr_bit_size / 8;
if (stride_bytes % 16 != 0)
return NULL;
/* check implicit base ptr alignment */
if ((layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR && desc.cols * bit_size < 128) ||
(layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR && desc.rows * bit_size < 128))
return NULL;
/* LDSM loads n 8x8 16 bit matrices */
unsigned mat_size_bits = desc.rows * desc.cols * bit_size;
unsigned ldsm_count = mat_size_bits / (8 * 8 * 16);
/* TODO: split bigger ones into multiple LDSM calls */
if (ldsm_count > 4 || ldsm_count == 0)
return NULL;
if ((desc.use != GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) ||
(desc.use == GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR)) {
/* Quite the pain, might not be worth it */
if (ldsm_count >= 4)
return NULL;
/* We'd need to split the rows leading to unaligned loads */
if (ldsm_count >= 2 && (desc.rows / 2) * bit_size < 128)
return NULL;
}
/* Account for differences in tiling depending on the layout */
nir_def *offset;
nir_def *lane_id = nir_load_subgroup_invocation(b);
if (ldsm_count == 4 && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) {
nir_def *lower = nir_iand(b, lane_id, nir_imm_int(b, 0x0f));
nir_def *upper = nir_iand(b, lane_id, nir_imm_int(b, 0x10));
offset = nir_imul_imm(b, lower, stride_bytes);
offset = nir_iadd(b, offset, upper);
} else if (ldsm_count >= 2 && layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) {
nir_def *lower;
nir_def *lower_lo = nir_iand(b, lane_id, nir_imm_int(b, 0x07));
nir_def *upper = nir_iand(b, lane_id, nir_imm_int(b, 0x08));
if (ldsm_count == 4) {
nir_def *lower_hi = nir_iand(b, lane_id, nir_imm_int(b, 0x10));
lower = nir_ior(b, lower_lo, nir_ushr_imm(b, lower_hi, 1));
} else {
lower = lower_lo;
}
offset = nir_imul_imm(b, lower, stride_bytes);
offset = nir_iadd(b, offset, nir_ishl_imm(b, upper, 1));
} else {
offset = nir_imul_imm(b, lane_id, stride_bytes);
}
nir_def *base = intr->src[1].ssa;
offset = nir_u2uN(b, offset, base->bit_size);
nir_def *addr = nir_iadd(b, base, offset);
/* flip the layout for B matrices */
if (desc.use == GLSL_CMAT_USE_B) {
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR)
layout = GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
else if (layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR)
layout = GLSL_MATRIX_LAYOUT_ROW_MAJOR;
}
/* Each thread loads 32 bits per matrix */
assert(length * bit_size == 32 * ldsm_count);
return nir_cmat_load_shared_nv(b, length, bit_size, addr,
.num_matrices = ldsm_count,
.matrix_layout = layout);
}
static void
lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
{
struct nir_def *ldsm = try_lower_cmat_load_to_ldsm(b, intr);
if (ldsm) {
store_cmat_src(b, intr->src[0], ldsm);
return;
}
const struct glsl_cmat_description desc = cmat_src_desc(intr->src[0]);
const unsigned length = get_cmat_length(desc);
const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);