diff --git a/.pick_status.json b/.pick_status.json index 831aa8468e3..29a705acf9f 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -334,7 +334,7 @@ "description": "intel/brw: Fix handling of cmat_signed_mask", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": "6b14da33ad3aa8a30ed5e479eace8bc6470095a7", "notes": null diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index 809aa7f456d..8f1ff3ed0e3 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -646,14 +646,44 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) const unsigned packing_factor = get_packing_factor(dst_desc, dst_slice->type); const unsigned num_components = glsl_get_vector_elements(dst_slice->type); + const nir_cmat_signed cmat_signed_mask = + nir_intrinsic_cmat_signed_mask(intrin); + + assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) == + ((cmat_signed_mask & NIR_CMAT_B_SIGNED) == 0)); + assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) == + ((cmat_signed_mask & NIR_CMAT_C_SIGNED) == 0)); + assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) == + ((cmat_signed_mask & NIR_CMAT_RESULT_SIGNED) == 0)); + + nir_alu_type src_type = + nir_get_nir_type_for_glsl_base_type(src_desc.element_type); + nir_alu_type dest_type = + nir_get_nir_type_for_glsl_base_type(dst_desc.element_type); + + /* For integer types, the signedness is determined by flags on the + * muladd instruction. The types of the sources play no role. Adjust the + * types passed to the dpas_intel intrinsic to match. + */ + if (nir_alu_type_get_base_type(src_type) == nir_type_uint || + nir_alu_type_get_base_type(src_type) == nir_type_int) { + if ((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) { + src_type = nir_alu_type_get_type_size(src_type) | nir_type_uint; + dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_uint; + } else { + src_type = nir_alu_type_get_type_size(src_type) | nir_type_int; + dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_int; + } + } + nir_def *result = nir_dpas_intel(b, packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type), nir_load_deref(b, accum_slice), nir_load_deref(b, A_slice), nir_load_deref(b, B_slice), - .dest_type = nir_get_nir_type_for_glsl_base_type(dst_desc.element_type), - .src_type = nir_get_nir_type_for_glsl_base_type(src_desc.element_type), + .dest_type = dest_type, + .src_type = src_type, .saturate = nir_intrinsic_saturate(intrin), .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin), .systolic_depth = 8,