nak/cmat: add alignment info to matrix load/stores

Reviewed-by: Mel Henning <mhenning@darkrefraction.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37998>
This commit is contained in:
Karol Herbst 2025-11-04 12:37:03 +01:00 committed by Marge Bot
parent a643681dd5
commit 79b3386810

View file

@ -680,6 +680,7 @@ get_cmat_component_deref(nir_builder *b, nir_intrinsic_instr *intr,
const struct glsl_cmat_description desc = cmat_src_desc(intr->src[cmat_src]);
nir_deref_instr *deref = nir_def_as_deref(intr->src[deref_src].ssa);
unsigned type_size_B = glsl_base_type_bit_size(desc.element_type) / 8;
const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
nir_def *stride = intr->src[2].ssa;
@ -693,18 +694,40 @@ get_cmat_component_deref(nir_builder *b, nir_intrinsic_instr *intr,
col_offset = nir_u2uN(b, col_offset, deref->def.bit_size);
row_offset = nir_u2uN(b, row_offset, deref->def.bit_size);
unsigned align_mul = 0, align_offset = 0, combined_align = 0;
nir_get_explicit_deref_align(deref, false, &align_mul, &align_offset);
if (align_mul)
combined_align = nir_combined_align(align_mul, align_offset);
/* VUID-RuntimeSpirv-OpCooperativeMatrixLoadKHR-08986:
* For OpCooperativeMatrixLoadKHR and OpCooperativeMatrixStoreKHR
* instructions, the Pointer and Stride operands must be aligned to at least
* the lesser of 16 bytes or the natural alignment of a row or column
* (depending on ColumnMajor) of the matrix (where the natural alignment is
* the number of columns/rows multiplied by the component size) */
unsigned align_elems =
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? desc.rows : desc.cols;
unsigned implicit_align = MIN2(16, align_elems * type_size_B);
if (implicit_align > combined_align) {
align_mul = implicit_align;
align_offset = 0;
}
/* We have to ignore the incoming stride, but have to choose the type of
* the pointer as the declared stride is in multiple of the pointer type */
deref = nir_build_deref_cast(
deref = nir_build_deref_cast_with_alignment(
b, &deref->def, deref->modes,
deref->type,
glsl_get_vector_elements(deref->type) * glsl_get_bit_size(deref->type) / 8
glsl_get_vector_elements(deref->type) * glsl_get_bit_size(deref->type) / 8,
align_mul,
align_offset
);
deref = nir_build_deref_ptr_as_array(b, deref, row_offset);
deref = nir_build_deref_cast(
b, &deref->def, deref->modes,
glsl_scalar_type(desc.element_type),
glsl_base_type_bit_size(desc.element_type) / 8);
type_size_B);
return nir_build_deref_ptr_as_array(b, deref, col_offset);
}