diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index a42304ae715..2649e137aa1 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -348,18 +348,28 @@ lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_def *results[NIR_MAX_VEC_COMPONENTS]; const unsigned num_components = glsl_get_vector_elements(dst_slice->type); - ASSERTED const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice); + const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice); ASSERTED const struct glsl_type *src_a_mat_type = get_coop_type_for_slice(state, src_a_slice); ASSERTED const struct glsl_type *src_b_mat_type = get_coop_type_for_slice(state, src_b_slice); + const struct glsl_cmat_description desc = + *glsl_get_cmat_description(dst_mat_type); + assert(dst_mat_type == src_a_mat_type); assert(dst_mat_type == src_b_mat_type); + const unsigned bits = glsl_base_type_bit_size(desc.element_type); + const unsigned packing_factor = get_packing_factor(desc, dst_slice->type); + for (unsigned i = 0; i < num_components; i++) { nir_def *val_a = nir_channel(b, src_a, i); nir_def *val_b = nir_channel(b, src_b, i); - results[i] = nir_build_alu2(b, nir_intrinsic_alu_op(intrin), val_a, val_b); + results[i] = + nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin), + nir_unpack_bits(b, val_a, bits), + nir_unpack_bits(b, val_b, bits)), + packing_factor * bits); } nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),