diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index 6887d37b841..a42304ae715 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -252,13 +252,83 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin, struct lower_cmat_state *state) { nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); - nir_def *src = nir_load_deref(b, nir_src_as_deref(intrin->src[1])); + nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]); nir_def *results[NIR_MAX_VEC_COMPONENTS]; const unsigned num_components = glsl_get_vector_elements(dst_slice->type); - for (unsigned i = 0; i < num_components; i++) { - nir_def *val = nir_channel(b, src, i); - results[i] = nir_build_alu1(b, nir_intrinsic_alu_op(intrin), val); + const struct glsl_type *dst_mat_type = + get_coop_type_for_slice(state, dst_slice); + const struct glsl_type *src_mat_type = + get_coop_type_for_slice(state, src_slice); + + const struct glsl_cmat_description dst_desc = + *glsl_get_cmat_description(dst_mat_type); + + const struct glsl_cmat_description src_desc = + *glsl_get_cmat_description(src_mat_type); + + const unsigned dst_bits = glsl_base_type_bit_size(dst_desc.element_type); + const unsigned src_bits = glsl_base_type_bit_size(src_desc.element_type); + + /* The type of the returned slice may be different from the type of the + * input slice. + */ + const unsigned dst_packing_factor = + get_packing_factor(dst_desc, dst_slice->type); + + const unsigned src_packing_factor = + get_packing_factor(src_desc, src_slice->type); + + const nir_op op = nir_intrinsic_alu_op(intrin); + + /* There are three possible cases: + * + * 1. dst_packing_factor == src_packing_factor. This is the common case, + * and handling it is straightforward. + * + * 2. dst_packing_factor > src_packing_factor. This occurs when converting a + * float32_t matrix slice to a packed float16_t slice. Loop over the size + * of the destination slice, but read multiple entries from the source + * slice on each iteration. + * + * 3. dst_packing_factor < src_packing_factor. This occurs when converting a + * packed int8_t matrix slice to an int32_t slice. Loop over the size of + * the source slice, but write multiple entries to the destination slice + * on each iteration. + * + * Handle all cases by iterating over the total (non-packed) number of + * elements in the slice. When dst_packing_factor values have been + * calculated, store them. + */ + assert((dst_packing_factor * glsl_get_vector_elements(dst_slice->type)) == + (src_packing_factor * glsl_get_vector_elements(src_slice->type))); + + /* Stores at most dst_packing_factor partial results. */ + nir_def *v[4]; + assert(dst_packing_factor <= 4); + + for (unsigned i = 0; i < num_components * dst_packing_factor; i++) { + const unsigned dst_chan_index = i % dst_packing_factor; + const unsigned src_chan_index = i % src_packing_factor; + const unsigned dst_index = i / dst_packing_factor; + const unsigned src_index = i / src_packing_factor; + + nir_def *src = + nir_channel(b, + nir_unpack_bits(b, + nir_channel(b, + nir_load_deref(b, src_slice), + src_index), + src_bits), + src_chan_index); + + v[dst_chan_index] = nir_build_alu1(b, op, src); + + if (dst_chan_index == (dst_packing_factor - 1)) { + results[dst_index] = + nir_pack_bits(b, nir_vec(b, v, dst_packing_factor), + dst_packing_factor * dst_bits); + } } nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), @@ -362,6 +432,17 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) case nir_intrinsic_cmat_construct: { nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]); nir_def *src = intrin->src[1].ssa; + + const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice); + const struct glsl_cmat_description desc = + *glsl_get_cmat_description(mat_type); + const unsigned packing_factor = get_packing_factor(desc, slice->type); + + if (packing_factor > 1) { + src = nir_pack_bits(b, nir_replicate(b, src, packing_factor), + packing_factor * glsl_base_type_get_bit_size(desc.element_type)); + } + const unsigned num_components = glsl_get_vector_elements(slice->type); nir_store_deref(b, slice, nir_replicate(b, src, num_components), @@ -385,7 +466,8 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin); const struct glsl_type *mat_type = glsl_cmat_type(&desc); const struct glsl_type *slice_type = get_slice_type(state, mat_type); - return nir_imm_intN_t(b, glsl_get_vector_elements(slice_type), 32); + return nir_imm_intN_t(b, (get_packing_factor(desc, slice_type) * + glsl_get_vector_elements(slice_type)), 32); } case nir_intrinsic_cmat_muladd: