diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
index 16a9bd2f5b2..0efc3218f71 100644
--- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
+++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
@@ -6,6 +6,46 @@
 /**
  * \file brw_nir_lower_cooperative_matrix.c
  * Lower cooperative matrix to subgroup operations.
+ *
+ * All supported matrix types are assumed to have either 8 rows or 8
+ * columns. The other dimension of the matrix is typically 8 times the number
+ * of data elements that can be stored in a 32-bit dword. Matrix data is
+ * indexed by a combination of an array element and a subgroup invocation ID.
+ *
+ * Two layouts for matrix data are used. In the first layout,
+ * subgroupShuffle(slice[N], ...) accesses row N of the matrix. This will be
+ * called row-major hereafter. In the other layout,
+ * subgroupShuffle(slice[...], M) accesses column M of the matrix. This will
+ * be called column-major hereafter. In cases where a single 32-bit value is
+ * stored in each entry, these layouts are identical.
+ *
+ * The subtle difference arises when multiple values are packed into a single
+ * 32-bit dword. If two 16-bit values are packed in a single 32-bit value in
+ * column-major, subgroupShuffle(slice[0], 1) holds matrix entries m[0][1] and
+ * m[1][1] (in m[row][column] notation). In row-major, that same shuffle holds
+ * m[0][2] and m[0][3].
+ *
+ * There is an alternate way to think about the matrix layouts. Every matrix
+ * size supported by the Intel driver is either Sx8 (e.g., 16x8 for float16 B
+ * matrix) or Sx8T (e.g., 8x32 for int8 A matrix). The A matrix and B matrix
+ * layouts are such that a single 8 dword register holds an entire row of the
+ * matrix.
+ *
+ * Consider a matrix stored starting in register g32. In an A matrix, the
+ * packed dwords of g32 contain only the data for a single row of the
+ * matrix. g32 is row 0, g33 is row 1, etc. In a B matrix, the packed dwords
+ * of g(32+N).X contain only the data for a single column of the
+ * matrix. g[32:40].0 is column 0, g[32:40].1 is column 1, etc.
+ *
+ * This leads to some shenanigans in \c lower_cmat_load_store.
+ *
+ * In the common case, A, C, and result matrices are stored row major while B
+ * matrices are stored column major. This arrangement facilitates efficient
+ * dot product operations using DPAS or DP4A instructions.
+ *
+ * Future optimizations are possible when row and column major are
+ * flipped. That is, efficient dot products are also possible when A, C, and
+ * result matrices are column major while B is row major.
  */
 
 #include "brw_nir.h"
@@ -113,6 +153,10 @@ get_slice_type_from_desc(const struct lower_cmat_state *state,
 
    /* Adjust the packing factor so that each row of the matrix fills an
    * entire GRF.
+    *
+    * The in-register layout of B matrices is different, so those are handled
+    * more like column major (for row major matrices). See the file comment
+    * for more details.
     */
    const unsigned actual_cols = desc.use != GLSL_CMAT_USE_B ? desc.cols : desc.rows;
    while ((actual_cols / packing_factor) < 8) {
@@ -198,47 +242,134 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
    const unsigned mat_src = load ? 0 : 1;
    const unsigned ptr_src = load ? 1 : 0;
 
-   /* TODO: Column major. */
-   assert(nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR);
-
    nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]);
    const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
    const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type);
 
-   /* TODO: Dynamic stride. */
-   assert(nir_src_is_const(intrin->src[2]));
-
    nir_def *results[NIR_MAX_VEC_COMPONENTS];
    const unsigned num_components = glsl_get_vector_elements(slice->type);
+   const unsigned packing_factor = get_packing_factor(*desc, slice->type);
    nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]);
-   const unsigned stride = nir_src_as_uint(intrin->src[2]);
+   if ((nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
+       (desc->use != GLSL_CMAT_USE_B)) {
+      nir_def *stride = nir_udiv_imm(b, intrin->src[2].ssa, packing_factor);
 
-   const struct glsl_type *element_type =
-      glsl_get_array_element(slice->type);
+      const struct glsl_type *element_type =
+         glsl_scalar_type(glsl_get_base_type(slice->type));
 
-   const struct glsl_type *pointer_type =
-      glsl_array_type(element_type, MAX2(desc->rows, desc->cols) * stride,
-                      glsl_get_bit_size(element_type) / 8);
+      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes,
+                                     element_type,
+                                     glsl_get_bit_size(element_type) / 8);
 
-   pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, pointer_type,
-                                  glsl_get_bit_size(element_type) / 8);
+      nir_def *invocation = nir_load_subgroup_invocation(b);
+      nir_def *base_offset;
+      nir_def *step;
 
-   for (unsigned i = 0; i < num_components; i++) {
+      if (desc->use != GLSL_CMAT_USE_B) {
+         base_offset = nir_iadd(b,
+                                nir_imul(b,
+                                         nir_udiv_imm(b, invocation, 8),
+                                         stride),
+                                nir_umod_imm(b, invocation, 8));
 
-      nir_def *offset = nir_imul_imm(b, nir_load_subgroup_invocation(b),
-                                     stride);
-      nir_deref_instr *memory_deref =
-         nir_build_deref_array(b, pointer,
-                               nir_i2iN(b, nir_iadd_imm(b, offset, i),
-                                        pointer->def.bit_size));
-
-      if (load) {
-         results[i] = nir_load_deref(b, memory_deref);
+         step = nir_imul_imm(b, stride, state->subgroup_size / 8);
       } else {
-         nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
-         nir_store_deref(b, memory_deref, src, 0x1);
+         base_offset = nir_iadd(b,
+                                nir_imul(b,
+                                         nir_umod_imm(b, invocation, 8),
+                                         stride),
+                                nir_udiv_imm(b, invocation, 8));
+
+         step = nir_imm_int(b, state->subgroup_size / 8);
+      }
+
+      for (unsigned i = 0; i < num_components; i++) {
+         nir_def *offset = nir_imul_imm(b, step, i);
+
+         nir_deref_instr *memory_deref =
+            nir_build_deref_ptr_as_array(b, pointer,
+                                         nir_i2iN(b,
+                                                  nir_iadd(b,
+                                                           base_offset,
+                                                           offset),
+                                                  pointer->def.bit_size));
+
+         if (load) {
+            results[i] = nir_load_deref(b, memory_deref);
+         } else {
+            nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
+            nir_store_deref(b, memory_deref, src, 0x1);
+         }
+      }
+   } else {
+      nir_def *stride = intrin->src[2].ssa;
+
+      const struct glsl_type *element_type = glsl_scalar_type(desc->element_type);
+      const unsigned element_bits = glsl_base_type_get_bit_size(desc->element_type);
+      const unsigned element_stride = element_bits / 8;
+
+      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type,
+                                     element_stride);
+
+      nir_def *invocation_div_8 = nir_udiv_imm(b, nir_load_subgroup_invocation(b), 8);
+      nir_def *invocation_mod_8 = nir_umod_imm(b, nir_load_subgroup_invocation(b), 8);
+
+      nir_def *packed_stride = nir_imul_imm(b, stride, packing_factor);
+
+      for (unsigned i = 0; i < num_components; i++) {
+         const unsigned i_offset = i * (state->subgroup_size / 8);
+         nir_def *v[4];
+
+         for (unsigned j = 0; j < packing_factor; j++) {
+            nir_def *j_offset = nir_imul_imm(b, stride, j);
+            nir_def *offset;
+
+            if (desc->use != GLSL_CMAT_USE_B) {
+               offset = nir_iadd(b,
+                                 nir_iadd(b,
+                                          nir_imul(b,
+                                                   invocation_mod_8,
+                                                   packed_stride),
+                                          invocation_div_8),
+                                 nir_iadd_imm(b, j_offset, i_offset));
+            } else {
+               offset = nir_iadd(b,
+                                 nir_iadd(b,
+                                          nir_imul(b,
+                                                   invocation_div_8,
+                                                   packed_stride),
+                                          invocation_mod_8),
+                                 nir_iadd(b,
+                                          nir_imul_imm(b,
+                                                       packed_stride,
+                                                       i_offset),
+                                          j_offset));
+            }
+
+            nir_deref_instr *memory_deref =
+               nir_build_deref_ptr_as_array(b, pointer,
+                                            nir_i2iN(b,
+                                                     offset,
+                                                     pointer->def.bit_size));
+
+            if (load) {
+               v[j] = nir_load_deref(b, memory_deref);
+            } else {
+               nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
+
+               nir_def *v =
+                  nir_channel(b, nir_unpack_bits(b, src, element_bits), j);
+
+               nir_store_deref(b, memory_deref, v, 0x1);
+            }
+         }
+
+         if (load) {
+            results[i] = nir_pack_bits(b, nir_vec(b, v, packing_factor),
+                                       packing_factor * element_bits);
+         }
+      }
+   }
       }
    }
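
To make the new file comment's addressing concrete, the following CPU-side sketch models the two slice layouts it describes. This is illustrative only, not part of the patch: the helpers row_major_coord() and col_major_coord() are hypothetical names, and the model assumes 8-dword rows and a subgroup size that is a multiple of 8. Each helper maps a matrix coordinate m[row][col] to the slice component, subgroup invocation, and packed element that hold it.

#include <stdio.h>

struct slice_coord {
   unsigned component;   /* index into the slice array */
   unsigned invocation;  /* subgroup invocation holding the dword */
   unsigned packed;      /* element within the 32-bit dword */
};

/* Row-major: slice[N] spans row N.  Each group of 8 invocations covers the
 * 8 dwords of one row, and packed elements advance along the columns.
 */
static struct slice_coord
row_major_coord(unsigned row, unsigned col, unsigned packing_factor,
                unsigned subgroup_size)
{
   const unsigned rows_per_component = subgroup_size / 8;

   return (struct slice_coord){
      .component = row / rows_per_component,
      .invocation = (row % rows_per_component) * 8 + col / packing_factor,
      .packed = col % packing_factor,
   };
}

/* Column-major: the shuffled invocation selects the column.  Packed
 * elements advance down the rows, 8 dwords per column per component.
 */
static struct slice_coord
col_major_coord(unsigned row, unsigned col, unsigned packing_factor,
                unsigned subgroup_size)
{
   const unsigned dword = row / packing_factor;
   const unsigned dwords_per_component = subgroup_size / 8;

   return (struct slice_coord){
      .component = dword / dwords_per_component,
      .invocation = (dword % dwords_per_component) * 8 + col,
      .packed = row % packing_factor,
   };
}

int main(void)
{
   /* Reproduce the file comment's example (two 16-bit values per dword,
    * subgroup size 8): subgroupShuffle(slice[0], 1) is component 0 of
    * invocation 1.
    */
   for (unsigned row = 0; row < 8; row++) {
      for (unsigned col = 0; col < 8; col++) {
         struct slice_coord rm = row_major_coord(row, col, 2, 8);
         struct slice_coord cm = col_major_coord(row, col, 2, 8);

         if (rm.component == 0 && rm.invocation == 1)
            printf("row-major:    m[%u][%u]\n", row, col);
         if (cm.component == 0 && cm.invocation == 1)
            printf("column-major: m[%u][%u]\n", row, col);
      }
   }

   return 0;
}

With those parameters the loop prints m[0][2] and m[0][3] for row-major and m[0][1] and m[1][1] for column-major, matching the example in the file comment.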
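The "matching layout" branch of lower_cmat_load_store (row-major memory for A/C/result, column-major memory for B) reduces to a single dword-granular offset per slice component. The scalar sketch below transcribes the patch's base_offset and step arithmetic; matching_offset() is a hypothetical name, and stride is in dwords, i.e. already divided by the packing factor as in the patch.

#include <stdbool.h>

/* Dword offset read or written by one invocation for slice component i in
 * the matching-layout path.
 */
static unsigned
matching_offset(bool is_b_matrix, unsigned invocation, unsigned i,
                unsigned stride, unsigned subgroup_size)
{
   /* A/C/result: each group of 8 invocations walks one row of 8
    * consecutive dwords, so a row lands in a single GRF.  B: invocations
    * 0-7 touch the first dword of columns 0-7, and each component steps
    * down the columns.
    */
   const unsigned base = is_b_matrix
      ? (invocation % 8) * stride + invocation / 8
      : (invocation / 8) * stride + invocation % 8;
   const unsigned step = is_b_matrix
      ? subgroup_size / 8
      : stride * (subgroup_size / 8);

   return base + i * step;
}

For a 16-wide subgroup and an A matrix, invocations 0-7 read dwords 0-7 of row 0 while invocations 8-15 read row 1, and each successive component advances two rows.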
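The other branch handles mismatched layouts by gathering or scattering individual elements and repacking them through nir_unpack_bits/nir_pack_bits. The sketch below models the load side for 16-bit elements (packing factor 2). It is an assumption-laden illustration: load_packed_dword() is a hypothetical name, and placing element j in bits [16*j, 16*j+15] assumes little-endian packing of the gathered vector; only the i_offset/j_offset address math is taken directly from the patch.

#include <stdint.h>
#include <stdbool.h>

/* Gather the elements that one invocation packs into slice component i
 * when the memory layout does not match the slice layout.  stride is in
 * elements; mem points at the matrix base.
 */
static uint32_t
load_packed_dword(const uint16_t *mem, bool is_b_matrix, unsigned invocation,
                  unsigned i, unsigned stride, unsigned subgroup_size)
{
   const unsigned packing_factor = 2;   /* two 16-bit elements per dword */
   const unsigned packed_stride = stride * packing_factor;
   const unsigned i_offset = i * (subgroup_size / 8);
   uint32_t dword = 0;

   for (unsigned j = 0; j < packing_factor; j++) {
      /* A/C/result slice gathered from column-major memory vs. B slice
       * gathered from row-major memory; both walk j along the minor
       * direction of the matrix.
       */
      const unsigned offset = is_b_matrix
         ? (invocation / 8) * packed_stride + invocation % 8 +
           packed_stride * i_offset + stride * j
         : (invocation % 8) * packed_stride + invocation / 8 +
           stride * j + i_offset;

      /* Assumed bit placement: element j occupies bits [16*j, 16*j+15]. */
      dword |= (uint32_t)mem[offset] << (16 * j);
   }

   return dword;
}

A store runs the same address math in reverse: unpack element j from the 32-bit slice component and write it to mem[offset], which is what the nir_unpack_bits/nir_store_deref pair in the patch does.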