diff --git a/src/nouveau/compiler/nak_nir_lower_cmat.c b/src/nouveau/compiler/nak_nir_lower_cmat.c
index 1eac30da246..313a7f54061 100644
--- a/src/nouveau/compiler/nak_nir_lower_cmat.c
+++ b/src/nouveau/compiler/nak_nir_lower_cmat.c
@@ -570,9 +570,116 @@ lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intr, nir_def *cmat,
    return cmat;
 }
 
+static struct nir_def*
+try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
+{
+   assert(intr->intrinsic == nir_intrinsic_cmat_load);
+
+   enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
+
+   const struct glsl_cmat_description desc = cmat_src_desc(intr->src[0]);
+   const unsigned length = get_cmat_length(desc);
+   nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
+   const unsigned ptr_bit_size = glsl_get_bit_size(deref->type);
+   const unsigned vec = glsl_get_vector_elements(deref->type);
+   nir_src stride = intr->src[2];
+
+   /* Even though LDSM operates on 16 bit types, the int8 matrix layout is
+    * compatible, so we can use LDSM on it as well. But we can't use it on
+    * the 32 bit types, because those use a different data layout at the
+    * byte level.
+    */
+   const unsigned bit_size = glsl_base_type_bit_size(desc.element_type);
+   if (!nir_src_is_const(stride)
+       || !nir_deref_mode_is(deref, nir_var_mem_shared)
+       || bit_size > 16)
+      return NULL;
+
+   /* The stride is in elements of the pointed-to type, which is not
+    * necessarily the element type of the referenced matrix.
+    */
+   unsigned stride_bytes = nir_src_as_uint(stride) * vec * ptr_bit_size / 8;
+   if (stride_bytes % 16 != 0)
+      return NULL;
+
+   /* Check the implicit alignment of the base pointer. */
+   if ((layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR && desc.cols * bit_size < 128) ||
+       (layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR && desc.rows * bit_size < 128))
+      return NULL;
+
+   /* LDSM loads n 8x8 matrices of 16 bit elements. */
+   unsigned mat_size_bits = desc.rows * desc.cols * bit_size;
+   unsigned ldsm_count = mat_size_bits / (8 * 8 * 16);
+
+   /* TODO: split bigger ones into multiple LDSM calls */
+   if (ldsm_count > 4 || ldsm_count == 0)
+      return NULL;
+
+   if ((desc.use != GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) ||
+       (desc.use == GLSL_CMAT_USE_B && layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR)) {
+      /* Quite the pain, might not be worth it */
+      if (ldsm_count >= 4)
+         return NULL;
+
+      /* We'd need to split the rows, leading to unaligned loads. */
+      if (ldsm_count >= 2 && (desc.rows / 2) * bit_size < 128)
+         return NULL;
+   }
+
+   /* Account for differences in tiling depending on the layout. */
+   nir_def *offset;
+   nir_def *lane_id = nir_load_subgroup_invocation(b);
+   if (ldsm_count == 4 && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) {
+      nir_def *lower = nir_iand(b, lane_id, nir_imm_int(b, 0x0f));
+      nir_def *upper = nir_iand(b, lane_id, nir_imm_int(b, 0x10));
+
+      offset = nir_imul_imm(b, lower, stride_bytes);
+      offset = nir_iadd(b, offset, upper);
+   } else if (ldsm_count >= 2 && layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) {
+      nir_def *lower;
+      nir_def *lower_lo = nir_iand(b, lane_id, nir_imm_int(b, 0x07));
+      nir_def *upper = nir_iand(b, lane_id, nir_imm_int(b, 0x08));
+      if (ldsm_count == 4) {
+         nir_def *lower_hi = nir_iand(b, lane_id, nir_imm_int(b, 0x10));
+         lower = nir_ior(b, lower_lo, nir_ushr_imm(b, lower_hi, 1));
+      } else {
+         lower = lower_lo;
+      }
+
+      offset = nir_imul_imm(b, lower, stride_bytes);
+      offset = nir_iadd(b, offset, nir_ishl_imm(b, upper, 1));
+   } else {
+      offset = nir_imul_imm(b, lane_id, stride_bytes);
+   }
+
+   nir_def *base = intr->src[1].ssa;
+   offset = nir_u2uN(b, offset, base->bit_size);
+   nir_def *addr = nir_iadd(b, base, offset);
+
+   /* Flip the layout for B matrices. */
+   if (desc.use == GLSL_CMAT_USE_B) {
+      if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR)
+         layout = GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
+      else if (layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR)
+         layout = GLSL_MATRIX_LAYOUT_ROW_MAJOR;
+   }
+
+   /* Each thread loads 32 bits per matrix. */
+   assert(length * bit_size == 32 * ldsm_count);
+   return nir_cmat_load_shared_nv(b, length, bit_size, addr,
+                                  .num_matrices = ldsm_count,
+                                  .matrix_layout = layout);
+}
+
 static void
 lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
 {
+   struct nir_def *ldsm = try_lower_cmat_load_to_ldsm(b, intr);
+   if (ldsm) {
+      store_cmat_src(b, intr->src[0], ldsm);
+      return;
+   }
+
    const struct glsl_cmat_description desc = cmat_src_desc(intr->src[0]);
    const unsigned length = get_cmat_length(desc);
    const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
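Note on the per-lane addressing: in the ldsm_count == 4, row-major branch above,
lanes 0-15 select matrix rows via (lane & 0x0f) * stride_bytes, and bit 4 of the
lane id (lane & 0x10, i.e. 0 or 16 bytes) selects the second 16-byte half of each
row, since LDSM expects every lane to supply the address of one 8-element,
16-byte row slice. A minimal standalone sketch of that branch (not part of the
patch; the stride_bytes value is an arbitrary example for a packed 16x16 fp16
matrix) that prints the byte offset each lane would pass to LDSM:

   #include <stdio.h>

   int main(void)
   {
      /* A packed, row-major 16x16 fp16 matrix: 16 elements * 2 bytes per row. */
      const unsigned stride_bytes = 32;

      for (unsigned lane = 0; lane < 32; lane++) {
         const unsigned lower = lane & 0x0f; /* row index, 0..15 */
         const unsigned upper = lane & 0x10; /* 0 or 16: left/right 16-byte half */
         const unsigned offset = lower * stride_bytes + upper;
         printf("lane %2u -> byte offset %3u\n", lane, offset);
      }
      return 0;
   }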