diff --git a/src/nouveau/compiler/nak_nir_lower_cmat.c b/src/nouveau/compiler/nak_nir_lower_cmat.c
index d1787ae12a5..b71a4301824 100644
--- a/src/nouveau/compiler/nak_nir_lower_cmat.c
+++ b/src/nouveau/compiler/nak_nir_lower_cmat.c
@@ -671,6 +671,28 @@ try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
                           .matrix_layout = layout);
 }
 
+/**
+ * Returns the possible vectorization width we can use to load/store matrices
+ * of the given cmat description and layout.
+ */
+static int load_store_get_vec_size(const struct glsl_cmat_description desc,
+                                   enum glsl_matrix_layout layout)
+{
+   if ((desc.use != GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_ROW_MAJOR) ||
+       (desc.use == GLSL_CMAT_USE_B && layout != GLSL_MATRIX_LAYOUT_COLUMN_MAJOR))
+      return 1;
+
+   switch (glsl_base_type_bit_size(desc.element_type)) {
+   case 16:
+   case 32:
+      return 2;
+   case 8:
+      return 4;
+   default:
+      return 1;
+   }
+}
+
 static nir_deref_instr*
 get_cmat_component_deref(nir_builder *b, nir_intrinsic_instr *intr,
                          nir_def *lane_id, unsigned idx)
@@ -741,6 +763,7 @@ lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
    }
 
    const struct glsl_cmat_description desc = cmat_src_desc(intr->src[0]);
+   const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
    const unsigned length = get_cmat_length(desc);
 
    nir_def *vars[NIR_MAX_VEC_COMPONENTS];
@@ -749,10 +772,19 @@ lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
 
    nir_def *lane_id = nir_load_subgroup_invocation(b);
 
-   for (unsigned idx = 0; idx < length; idx++) {
+   int vec_size = load_store_get_vec_size(desc, layout);
+   for (unsigned idx = 0; idx < length; idx += vec_size) {
       nir_deref_instr *iter_deref =
          get_cmat_component_deref(b, intr, lane_id, idx);
-      vars[idx] = nir_load_deref(b, iter_deref);
+      nir_variable_mode modes = iter_deref->modes;
+      const glsl_type *vec_type = glsl_vector_type(desc.element_type, vec_size);
+      iter_deref = nir_build_deref_cast_with_alignment(b,
+         &iter_deref->def, modes, vec_type,
+         0, vec_size * glsl_base_type_bit_size(desc.element_type) / 8, 0);
+
+      nir_def *value = nir_load_deref(b, iter_deref);
+      for (int c = 0; c < vec_size; c++)
+         vars[idx + c] = nir_channel(b, value, c);
    }
 
    nir_def *mat = nir_vec(b, vars, length);
@@ -803,6 +835,7 @@ lower_cmat_instr(nir_builder *b,
 
    case nir_intrinsic_cmat_store: {
       const struct glsl_cmat_description desc = cmat_src_desc(intr->src[1]);
+      const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
       const unsigned length = get_cmat_length(desc);
 
       nir_def *src = load_cmat_src(b, intr->src[1]);
@@ -812,10 +845,18 @@
 
       nir_def *lane_id = nir_load_subgroup_invocation(b);
 
-      for (unsigned idx = 0; idx < length; idx++) {
+      int vec_size = load_store_get_vec_size(desc, layout);
+      for (unsigned idx = 0; idx < length; idx += vec_size) {
         nir_deref_instr *iter_deref =
            get_cmat_component_deref(b, intr, lane_id, idx);
-         nir_store_deref(b, iter_deref, vars[idx], 1);
+
+         nir_variable_mode modes = iter_deref->modes;
+         const glsl_type *vec_type = glsl_vector_type(desc.element_type, vec_size);
+         iter_deref = nir_build_deref_cast_with_alignment(b,
+            &iter_deref->def, modes, vec_type,
+            0, vec_size * glsl_base_type_bit_size(desc.element_type) / 8, 0);
+         nir_def *value = nir_vec(b, &vars[idx], vec_size);
+         nir_store_deref(b, iter_deref, value, -1);
       }
 
      nir_instr_remove(instr);
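
For reference, here is a minimal standalone sketch of the width-selection rule implemented by load_store_get_vec_size() above. The enum names, the vec_size() helper, and the main() driver are hypothetical stand-ins invented for illustration (the real types live in Mesa's glsl_types.h); the selection logic itself mirrors the patch: non-B matrices vectorize only in row-major layout, B matrices only in column-major layout, presumably because those are the cases where one lane's consecutive matrix components are contiguous in memory.

/* vec_size_sketch.c -- illustration only, not Mesa API.
 * Mocked stand-ins for glsl_cmat_description.use, glsl_matrix_layout,
 * and glsl_base_type_bit_size() from the patch above. */
#include <stdio.h>

enum cmat_use { CMAT_USE_A, CMAT_USE_B, CMAT_USE_ACCUMULATOR };
enum matrix_layout { ROW_MAJOR, COLUMN_MAJOR };

static int vec_size(enum cmat_use use, enum matrix_layout layout,
                    int element_bits)
{
   /* Same condition as the patch: any use/layout pair other than
    * A/accumulator+row-major or B+column-major stays scalar. */
   if ((use != CMAT_USE_B && layout != ROW_MAJOR) ||
       (use == CMAT_USE_B && layout != COLUMN_MAJOR))
      return 1;

   /* 16- and 32-bit elements pair up into vec2 accesses,
    * 8-bit elements group into vec4 accesses. */
   switch (element_bits) {
   case 16:
   case 32:
      return 2;
   case 8:
      return 4;
   default:
      return 1;
   }
}

int main(void)
{
   printf("fp16 A,   row-major:    vec%d\n", vec_size(CMAT_USE_A, ROW_MAJOR, 16));
   printf("int8 B,   column-major: vec%d\n", vec_size(CMAT_USE_B, COLUMN_MAJOR, 8));
   printf("fp32 Acc, row-major:    vec%d\n", vec_size(CMAT_USE_ACCUMULATOR, ROW_MAJOR, 32));
   printf("fp16 B,   row-major:    vec%d\n", vec_size(CMAT_USE_B, ROW_MAJOR, 16));
   return 0;
}

Compiling and running this prints vec2, vec4, vec2, and vec1 respectively: the mismatched use/layout pair in the last case falls back to scalar accesses, matching the early return in the patch.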