nak/cmat: add optimisation to cmat load/store to do 32-bit load for f16vec2

Initial idea and code from Dave, but this is a complete rewrite of the
patch.

The matrix layouts contain groups of values: for int8 we have vec4
groups, and for fp16, fp32, and int32 we have vec2s. With this we load
and store them as vectors, getting rid of a bunch of address
calculations.
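
Roughly, for an fp16 matrix this turns one 16-bit access per component
into one 32-bit access per vec2 group. A plain C sketch of the idea (the
names here are made up for illustration; the real pass operates on NIR
derefs, as in the diff below, and the byte order assumes little-endian,
which holds for the GPUs NAK targets):

    #include <stdint.h>
    #include <string.h>

    /* before: one address calculation and one 16-bit access per element */
    static void load_pair_scalar(const uint16_t *mat, unsigned idx,
                                 uint16_t out[2])
    {
       out[0] = mat[idx + 0];
       out[1] = mat[idx + 1];
    }

    /* after: the whole vec2 group is fetched with a single 32-bit access */
    static void load_pair_vec2(const uint16_t *mat, unsigned idx,
                               uint16_t out[2])
    {
       uint32_t v;
       memcpy(&v, &mat[idx], sizeof(v)); /* one 32-bit load */
       out[0] = (uint16_t)(v & 0xffff);
       out[1] = (uint16_t)(v >> 16);
    }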

Reviewed-by: Mel Henning <mhenning@darkrefraction.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37998>
Author:    Karol Herbst
Date:      2025-10-22 12:54:12 +02:00
Committer: Marge Bot
parent     79b3386810
commit     6e89dc33fe

@@ -671,6 +671,28 @@ try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
                               .matrix_layout = layout);
 }
 
+/**
+ * Returns the possible vectorization width we can use to load/store
+ * matrices of the given cmat desc and layout.
+ */
+static int load_store_get_vec_size(const struct glsl_cmat_description desc,
+                                   enum glsl_matrix_layout layout)
+{
+   if ((desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
+       (desc.use == GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR))
+      return 1;
+
+   switch (glsl_base_type_bit_size(desc.element_type)) {
+   case 16:
+   case 32:
+      return 2;
+   case 8:
+      return 4;
+   default:
+      return 1;
+   }
+}
+
 static nir_deref_instr*
 get_cmat_component_deref(nir_builder *b, nir_intrinsic_instr *intr,
                          nir_def *lane_id, unsigned idx)
@@ -741,6 +763,7 @@ lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
    }
 
    const struct glsl_cmat_description desc = cmat_src_desc(intr->src[0]);
+   const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
    const unsigned length = get_cmat_length(desc);
 
    nir_def *vars[NIR_MAX_VEC_COMPONENTS];
@@ -749,10 +772,19 @@ lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
 
    nir_def *lane_id = nir_load_subgroup_invocation(b);
 
-   for (unsigned idx = 0; idx < length; idx++) {
+   int vec_size = load_store_get_vec_size(desc, layout);
+
+   for (unsigned idx = 0; idx < length; idx += vec_size) {
       nir_deref_instr *iter_deref =
          get_cmat_component_deref(b, intr, lane_id, idx);
-      vars[idx] = nir_load_deref(b, iter_deref);
+      nir_variable_mode modes = iter_deref->modes;
+      const struct glsl_type *vec_type = glsl_vector_type(desc.element_type, vec_size);
+      iter_deref = nir_build_deref_cast_with_alignment(b,
+         &iter_deref->def, modes, vec_type,
+         0, vec_size * glsl_base_type_bit_size(desc.element_type) / 8, 0);
+      nir_def *value = nir_load_deref(b, iter_deref);
+      for (int c = 0; c < vec_size; c++)
+         vars[idx + c] = nir_channel(b, value, c);
    }
 
    nir_def *mat = nir_vec(b, vars, length);
@@ -803,6 +835,7 @@ lower_cmat_instr(nir_builder *b,
    case nir_intrinsic_cmat_store: {
      const struct glsl_cmat_description desc = cmat_src_desc(intr->src[1]);
+     const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
      const unsigned length = get_cmat_length(desc);
 
      nir_def *src = load_cmat_src(b, intr->src[1]);
@@ -812,10 +845,18 @@ lower_cmat_instr(nir_builder *b,
 
      nir_def *lane_id = nir_load_subgroup_invocation(b);
 
-     for (unsigned idx = 0; idx < length; idx++) {
+     int vec_size = load_store_get_vec_size(desc, layout);
+
+     for (unsigned idx = 0; idx < length; idx += vec_size) {
        nir_deref_instr *iter_deref =
           get_cmat_component_deref(b, intr, lane_id, idx);
-       nir_store_deref(b, iter_deref, vars[idx], 1);
+       nir_variable_mode modes = iter_deref->modes;
+       const struct glsl_type *vec_type = glsl_vector_type(desc.element_type, vec_size);
+       iter_deref = nir_build_deref_cast_with_alignment(b,
+          &iter_deref->def, modes, vec_type,
+          0, vec_size * glsl_base_type_bit_size(desc.element_type) / 8, 0);
+       nir_def *value = nir_vec(b, &vars[idx], vec_size);
+       nir_store_deref(b, iter_deref, value, -1);
      }
 
      nir_instr_remove(instr);
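
Note (my reading of load_store_get_vec_size() above, not part of the
patch itself): the vector path only triggers when consecutive components
are adjacent in memory, i.e. row-major layout for use-A/accumulator
matrices and column-major for use-B matrices; every other combination
keeps the scalar path. An illustrative check, assuming the Mesa enum
values used in the diff:

    struct glsl_cmat_description desc = {
       .use = GLSL_CMAT_USE_A,
       .element_type = GLSL_TYPE_FLOAT16,
    };
    /* row-major A matrix: fp16 pairs load as a single 32-bit value */
    assert(load_store_get_vec_size(desc, GLSL_MATRIX_LAYOUT_ROW_MAJOR) == 2);
    /* column-major A matrix: components are strided, so they stay scalar */
    assert(load_store_get_vec_size(desc, GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) == 1);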