intel/cmat: Lower cmat_load and cmat_store
v2: Add support for non-constant stride.

v3: Explain B matrices (a little bit) in get_slice_type_from_desc.
    Suggested by Caio.

Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25994>
parent 502be565da
commit 3a35f8b29b

1 changed file with 157 additions and 26 deletions
@@ -6,6 +6,46 @@
 /**
  * \file brw_nir_lower_cooperative_matrix.c
  * Lower cooperative matrix to subgroup operations.
+ *
+ * All supported matrix types are assumed to have either 8 rows or 8
+ * columns. The other dimension of the matrix is typically 8 times the number
+ * of data elements that can be stored in a 32-bit dword. Matrix data is
+ * indexed by a combination of an array element and a subgroup invocation ID.
+ *
+ * Two layouts for matrix data are used. In the first layout,
+ * subgroupShuffle(slice[N], ...) accesses row N of the matrix. This will be
+ * called row-major hereafter. In the other layout,
+ * subgroupShuffle(slice[...], M) accesses column M of the matrix. This will
+ * be called column-major hereafter. In cases where a single 32-bit value is
+ * stored in each entry, these layouts are identical.
+ *
+ * The subtle difference arises when multiple values are packed into a single
+ * 32-bit dword. If two 16-bit values are packed in a single 32-bit value in
+ * column-major, subgroupShuffle(slice[0], 1) holds matrix entries m[0][1] and
+ * m[1][1] (in m[row][column] notation). In row-major, that same shuffle holds
+ * m[0][2] and m[0][3].
+ *
+ * There is an alternate way to think about the matrix layouts. Every matrix
+ * size supported by the Intel driver is either Sx8 (e.g., 16x8 for float16 B
+ * matrix) or 8xS (e.g., 8x32 for int8 A matrix). The A matrix and B matrix
+ * layouts are such that a single 8-dword register holds an entire row of the
+ * matrix.
+ *
+ * Consider a matrix stored starting in register g32. In an A matrix, the
+ * packed dwords of g32 contain only the data for a single row of the
+ * matrix. g32 is row 0, g33 is row 1, etc. In a B matrix, the packed dwords
+ * of g(32+N).X contain only the data for a single column of the
+ * matrix. g[32:40].0 is column 0, g[32:40].1 is column 1, etc.
+ *
+ * This leads to some shenanigans in \c lower_cmat_load_store.
+ *
+ * In the common case, A, C, and result matrices are stored row major while B
+ * matrices are stored column major. This arrangement facilitates efficient
+ * dot product operations using DPAS or DP4A instructions.
+ *
+ * Future optimizations are possible when row and column major are
+ * flipped. That is, efficient dot products are also possible when A, C, and
+ * result matrices are column major while B is row major.
  */
 
 #include "brw_nir.h"
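To make the layout rules in the comment concrete, here is a minimal standalone C sketch (not part of the patch). It assumes the comment's float16 example: a packing factor of 2, i.e. two 16-bit elements per 32-bit dword, and maps a slice component plus subgroup lane to the matrix entries that dword holds in each layout:

#include <stdio.h>

#define PACKING 2   /* assumption: two 16-bit elements per 32-bit dword */

/* Print the matrix entries m[row][col] held by slice[component] in a
 * given subgroup lane, for both layouts described in the file comment. */
static void
print_dword_contents(unsigned component, unsigned lane)
{
   /* Row-major: the component selects the row; the lane's dword packs
    * PACKING consecutive elements of that row. */
   printf("row-major    slice[%u] @ lane %u:", component, lane);
   for (unsigned h = 0; h < PACKING; h++)
      printf(" m[%u][%u]", component, lane * PACKING + h);
   printf("\n");

   /* Column-major: the lane selects the column; the component's dword
    * packs PACKING consecutive elements of that column. */
   printf("column-major slice[%u] @ lane %u:", component, lane);
   for (unsigned h = 0; h < PACKING; h++)
      printf(" m[%u][%u]", component * PACKING + h, lane);
   printf("\n");
}

int
main(void)
{
   /* Reproduces the example above: the same shuffle of slice[0] from
    * lane 1 yields m[0][2],m[0][3] in row-major but m[0][1],m[1][1]
    * in column-major. */
   print_dword_contents(0, 1);
   return 0;
}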
@@ -113,6 +153,10 @@ get_slice_type_from_desc(const struct lower_cmat_state *state,
 
    /* Adjust the packing factor so that each row of the matrix fills an
     * entire GRF.
+    *
+    * The in-register layout of B matrices is different, so those are handled
+    * more like column major (for row major matrices). See the file comment
+    * for more details.
     */
    const unsigned actual_cols = desc.use != GLSL_CMAT_USE_B ? desc.cols : desc.rows;
    while ((actual_cols / packing_factor) < 8) {
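The hunk cuts off before the loop body, so the following is a hedged sketch rather than the driver's code: it assumes the loop halves packing_factor each iteration and that the starting value is the densest a 32-bit dword allows (32 / element_bits). Under those assumptions, the adjustment behaves like this:

#include <stdio.h>

/* Hypothetical helper mirroring the adjustment described above: halve
 * the packing until actual_cols / packing_factor == 8, i.e. until each
 * row of the slice fills an entire 8-dword GRF. The halving body is an
 * assumption; the hunk does not show it. */
static unsigned
packing_factor_for(unsigned element_bits, unsigned rows, unsigned cols,
                   int is_b_matrix)
{
   unsigned packing_factor = 32 / element_bits;   /* densest packing */
   const unsigned actual_cols = is_b_matrix ? rows : cols;

   while ((actual_cols / packing_factor) < 8)
      packing_factor /= 2;

   return packing_factor;
}

int
main(void)
{
   /* int8 A matrix, 8x32: four bytes per dword, 32/4 = 8 dwords/row. */
   printf("int8 8x32 A: %u\n", packing_factor_for(8, 8, 32, 0));
   /* float16 B matrix, 16x8: B uses rows as actual_cols, 16/2 = 8. */
   printf("fp16 16x8 B: %u\n", packing_factor_for(16, 16, 8, 1));
   /* Hypothetical 8x8 fp16 tile: must drop to 1 element per dword. */
   printf("fp16 8x8:    %u\n", packing_factor_for(16, 8, 8, 0));
   return 0;
}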
@@ -198,47 +242,134 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
    const unsigned mat_src = load ? 0 : 1;
    const unsigned ptr_src = load ? 1 : 0;
 
-   /* TODO: Column major. */
-   assert(nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR);
-
    nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]);
    const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
    const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type);
 
-   /* TODO: Dynamic stride. */
-   assert(nir_src_is_const(intrin->src[2]));
-
    nir_def *results[NIR_MAX_VEC_COMPONENTS];
    const unsigned num_components = glsl_get_vector_elements(slice->type);
    const unsigned packing_factor = get_packing_factor(*desc, slice->type);
 
    nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]);
 
-   const unsigned stride = nir_src_as_uint(intrin->src[2]);
-
-   const struct glsl_type *element_type =
-      glsl_get_array_element(slice->type);
-
-   const struct glsl_type *pointer_type =
-      glsl_array_type(element_type, MAX2(desc->rows, desc->cols) * stride,
-                      glsl_get_bit_size(element_type) / 8);
-
-   pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, pointer_type,
-                                  glsl_get_bit_size(element_type) / 8);
-
-   for (unsigned i = 0; i < num_components; i++) {
-      nir_def *offset = nir_imul_imm(b, nir_load_subgroup_invocation(b),
-                                     stride);
-      nir_deref_instr *memory_deref =
-         nir_build_deref_array(b, pointer,
-                               nir_i2iN(b, nir_iadd_imm(b, offset, i),
-                                        pointer->def.bit_size));
-
-      if (load) {
-         results[i] = nir_load_deref(b, memory_deref);
-      } else {
-         nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
-         nir_store_deref(b, memory_deref, src, 0x1);
-      }
-   }
+   if ((nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
+       (desc->use != GLSL_CMAT_USE_B)) {
+      nir_def *stride = nir_udiv_imm(b, intrin->src[2].ssa, packing_factor);
+
+      const struct glsl_type *element_type =
+         glsl_scalar_type(glsl_get_base_type(slice->type));
+
+      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes,
+                                     element_type,
+                                     glsl_get_bit_size(element_type) / 8);
+
+      nir_def *invocation = nir_load_subgroup_invocation(b);
+      nir_def *base_offset;
+      nir_def *step;
+
+      if (desc->use != GLSL_CMAT_USE_B) {
+         base_offset = nir_iadd(b,
+                                nir_imul(b,
+                                         nir_udiv_imm(b, invocation, 8),
+                                         stride),
+                                nir_umod_imm(b, invocation, 8));
+
+         step = nir_imul_imm(b, stride, state->subgroup_size / 8);
+      } else {
+         base_offset = nir_iadd(b,
+                                nir_imul(b,
+                                         nir_umod_imm(b, invocation, 8),
+                                         stride),
+                                nir_udiv_imm(b, invocation, 8));
+
+         step = nir_imm_int(b, state->subgroup_size / 8);
+      }
+
+      for (unsigned i = 0; i < num_components; i++) {
+         nir_def *offset = nir_imul_imm(b, step, i);
+
+         nir_deref_instr *memory_deref =
+            nir_build_deref_ptr_as_array(b, pointer,
+                                         nir_i2iN(b,
+                                                  nir_iadd(b,
+                                                           base_offset,
+                                                           offset),
+                                                  pointer->def.bit_size));
+
+         if (load) {
+            results[i] = nir_load_deref(b, memory_deref);
+         } else {
+            nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
+            nir_store_deref(b, memory_deref, src, 0x1);
+         }
+      }
+   } else {
+      nir_def *stride = intrin->src[2].ssa;
+
+      const struct glsl_type *element_type = glsl_scalar_type(desc->element_type);
+      const unsigned element_bits = glsl_base_type_get_bit_size(desc->element_type);
+      const unsigned element_stride = element_bits / 8;
+
+      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type,
+                                     element_stride);
+
+      nir_def *invocation_div_8 = nir_udiv_imm(b, nir_load_subgroup_invocation(b), 8);
+      nir_def *invocation_mod_8 = nir_umod_imm(b, nir_load_subgroup_invocation(b), 8);
+
+      nir_def *packed_stride = nir_imul_imm(b, stride, packing_factor);
+
+      for (unsigned i = 0; i < num_components; i++) {
+         const unsigned i_offset = i * (state->subgroup_size / 8);
+         nir_def *v[4];
+
+         for (unsigned j = 0; j < packing_factor; j++) {
+            nir_def *j_offset = nir_imul_imm(b, stride, j);
+            nir_def *offset;
+
+            if (desc->use != GLSL_CMAT_USE_B) {
+               offset = nir_iadd(b,
+                                 nir_iadd(b,
+                                          nir_imul(b,
+                                                   invocation_mod_8,
+                                                   packed_stride),
+                                          invocation_div_8),
+                                 nir_iadd_imm(b, j_offset, i_offset));
+            } else {
+               offset = nir_iadd(b,
+                                 nir_iadd(b,
+                                          nir_imul(b,
+                                                   invocation_div_8,
+                                                   packed_stride),
+                                          invocation_mod_8),
+                                 nir_iadd(b,
+                                          nir_imul_imm(b,
+                                                       packed_stride,
+                                                       i_offset),
+                                          j_offset));
+            }
+
+            nir_deref_instr *memory_deref =
+               nir_build_deref_ptr_as_array(b, pointer,
+                                            nir_i2iN(b,
+                                                     offset,
+                                                     pointer->def.bit_size));
+
+            if (load) {
+               v[j] = nir_load_deref(b, memory_deref);
+            } else {
+               nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
+
+               nir_def *unpacked =
+                  nir_channel(b, nir_unpack_bits(b, src, element_bits), j);
+
+               nir_store_deref(b, memory_deref, unpacked, 0x1);
+            }
+         }
+
+         if (load) {
+            results[i] = nir_pack_bits(b, nir_vec(b, v, packing_factor),
+                                       packing_factor * element_bits);
+         }
+      }
+   }
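The matched-layout path above builds each lane's address as base_offset + i * step. The standalone sketch below (not part of the patch; it assumes a SIMD16 subgroup and a stride already divided by the packing factor, as in the code) evaluates that arithmetic on the CPU so the access pattern is visible:

#include <stdio.h>

#define SUBGROUP_SIZE 16   /* assumption: SIMD16 */

/* Mirror of the base_offset/step expressions from the matched-layout
 * branch: A/C/result lanes walk along rows, B lanes walk down columns. */
static unsigned
offset_for(unsigned invocation, unsigned component, unsigned stride,
           int is_b_matrix)
{
   unsigned base, step;

   if (!is_b_matrix) {
      /* Groups of 8 lanes cover consecutive rows; each slice component
       * advances by subgroup_size / 8 rows. */
      base = (invocation / 8) * stride + (invocation % 8);
      step = stride * (SUBGROUP_SIZE / 8);
   } else {
      /* Same idea with rows and columns swapped. */
      base = (invocation % 8) * stride + (invocation / 8);
      step = SUBGROUP_SIZE / 8;
   }

   return base + component * step;
}

int
main(void)
{
   const unsigned stride = 8;   /* e.g., an 8-dword-wide tile */

   for (unsigned inv = 0; inv < SUBGROUP_SIZE; inv++)
      printf("inv %2u: A comp0 -> %2u, B comp0 -> %2u\n",
             inv, offset_for(inv, 0, stride, 0),
             offset_for(inv, 0, stride, 1));
   return 0;
}

For the A case this prints offsets 0..15 in order (lanes 0-7 cover dwords 0-7 of row 0, lanes 8-15 cover row 1), while the B case strides down a column, which matches the step values the lowered NIR computes.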