From 27faffb9765b86e6eadfd6ef8d195700d03ff64f Mon Sep 17 00:00:00 2001 From: Ian Romanick Date: Tue, 26 Mar 2024 16:14:33 -0700 Subject: [PATCH] nir: intel/brw: Change the order of sources for nir_dpas_intel It was by pure luck that all sources (and the result) of nir_dpas_intel had the same number of components. It is possible to support matrix sizes where the accumulator matrix and the result matrix are larger (e.g., 16x8 * 8x16 = 16x16). This breaks all of the assumptions of NIR's infrastructure for code generating intrinsics. Fix this by making the accumulator matrix be the first source. The accumulator and the result will always have the same dimensions (due to rules of matrix multiplication) and the same type (due to restrictions of the cooperative matrix extension). This forces them to have the same number of components. This doesn't fix all the potential problems. NIR expects that all 0-sized sources will have the same number of components. This just ensures that the result has the correct number of components. 
Fixes: 6b14da33ad3 ("intel/fs: nir: Add nir_intrinsic_dpas_intel") Reviewed-by: Jordan Justen Part-of: (cherry picked from commit a8115221e596a8bed7a64799ccc03aa9ad225d92) --- .pick_status.json | 2 +- src/compiler/nir/nir_intrinsics.py | 10 +++++++--- src/intel/compiler/brw_fs_nir.cpp | 18 +++++++++--------- .../brw_nir_lower_cooperative_matrix.c | 2 +- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/.pick_status.json b/.pick_status.json index c1412dacf71..40c35605021 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -1674,7 +1674,7 @@ "description": "nir: intel/brw: Change the order of sources for nir_dpas_intel", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": "6b14da33ad3aa8a30ed5e479eace8bc6470095a7", "notes": null diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 9a73af521b9..afa47d3dcac 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -2025,11 +2025,15 @@ system_value("leaf_procedural_intel", 1, bit_sizes=[1]) system_value("btd_shader_type_intel", 1) system_value("ray_query_global_intel", 1, bit_sizes=[64]) -# Source 0: A matrix (type specified by SRC_TYPE) -# Source 1: B matrix (type specified by SRC_TYPE) -# Source 2: Accumulator matrix (type specified by DEST_TYPE) +# Source 0: Accumulator matrix (type specified by DEST_TYPE) +# Source 1: A matrix (type specified by SRC_TYPE) +# Source 2: B matrix (type specified by SRC_TYPE) # # The matrix parameters are the slices owned by the invocation. +# +# The accumulator is source 0 because that is the source the intrinsic +# infrastructure in NIR uses to determine the number of components in the +# result. 
intrinsic("dpas_intel", dest_comp=0, src_comp=[0, 0, 0], indices=[DEST_TYPE, SRC_TYPE, SATURATE, CMAT_SIGNED_MASK, SYSTOLIC_DEPTH, REPEAT_COUNT], flags=[CAN_ELIMINATE]) diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp index ccdd0fe7db8..399fb5b80ba 100644 --- a/src/intel/compiler/brw_fs_nir.cpp +++ b/src/intel/compiler/brw_fs_nir.cpp @@ -4608,7 +4608,7 @@ fs_nir_emit_cs_intrinsic(nir_to_brw_state &ntb, brw_type_for_nir_type(devinfo, nir_intrinsic_src_type(instr)); dest = retype(dest, dest_type); - fs_reg src2 = retype(get_nir_src(ntb, instr->src[2]), dest_type); + fs_reg src0 = retype(get_nir_src(ntb, instr->src[0]), dest_type); const fs_reg dest_hf = dest; fs_builder bld8 = bld.exec_all().group(8, 0); @@ -4624,24 +4624,24 @@ fs_nir_emit_cs_intrinsic(nir_to_brw_state &ntb, !s.compiler->lower_dpas) { dest = bld8.vgrf(BRW_REGISTER_TYPE_F, rcount); - if (src2.file != ARF) { - const fs_reg src2_hf = src2; + if (src0.file != ARF) { + const fs_reg src0_hf = src0; - src2 = bld8.vgrf(BRW_REGISTER_TYPE_F, rcount); + src0 = bld8.vgrf(BRW_REGISTER_TYPE_F, rcount); for (unsigned i = 0; i < 4; i++) { - bld16.MOV(byte_offset(src2, REG_SIZE * i * 2), - byte_offset(src2_hf, REG_SIZE * i)); + bld16.MOV(byte_offset(src0, REG_SIZE * i * 2), + byte_offset(src0_hf, REG_SIZE * i)); } } else { - src2 = retype(src2, BRW_REGISTER_TYPE_F); + src0 = retype(src0, BRW_REGISTER_TYPE_F); } } bld8.DPAS(dest, - src2, + src0, + retype(get_nir_src(ntb, instr->src[2]), src_type), retype(get_nir_src(ntb, instr->src[1]), src_type), - retype(get_nir_src(ntb, instr->src[0]), src_type), sdepth, rcount) ->saturate = nir_intrinsic_saturate(instr); diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index f8743d89691..809aa7f456d 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -649,9 +649,9 @@ lower_cmat_instr(nir_builder 
*b, nir_instr *instr, void *_state) nir_def *result = nir_dpas_intel(b, packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type), + nir_load_deref(b, accum_slice), nir_load_deref(b, A_slice), nir_load_deref(b, B_slice), - nir_load_deref(b, accum_slice), .dest_type = nir_get_nir_type_for_glsl_base_type(dst_desc.element_type), .src_type = nir_get_nir_type_for_glsl_base_type(src_desc.element_type), .saturate = nir_intrinsic_saturate(intrin),