From bf9ad36f2dfffa3567e67f0da3f0f44c71a7b011 Mon Sep 17 00:00:00 2001
From: Caio Oliveira
Date: Thu, 3 Apr 2025 11:05:28 -0700
Subject: [PATCH] brw: Properly handle cooperative matrices created with constants

Expand constant sources to cover the region read by DPAS, and also
use NULL register as accumulator when possible.

Reviewed-by: Sushma Venkatesh Reddy
Reviewed-by: Ian Romanick
Part-of:
---
 src/intel/compiler/brw_from_nir.cpp | 59 +++++++++++++++++++++++++----
 1 file changed, 51 insertions(+), 8 deletions(-)

diff --git a/src/intel/compiler/brw_from_nir.cpp b/src/intel/compiler/brw_from_nir.cpp
index efbee8ce509..4ba0c814356 100644
--- a/src/intel/compiler/brw_from_nir.cpp
+++ b/src/intel/compiler/brw_from_nir.cpp
@@ -4697,16 +4697,59 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
       const brw_reg_type src_type =
          brw_type_for_nir_type(devinfo, nir_intrinsic_src_type(instr));
 
-      dest = retype(dest, dest_type);
-      brw_reg src0 = retype(get_nir_src(ntb, instr->src[0], 0), dest_type);
+      brw_reg src[3] = {};
+      for (unsigned i = 0; i < ARRAY_SIZE(src); i++) {
+         nir_src nsrc = instr->src[i];
 
-      brw_builder bld16 = bld.exec_all().group(16, 0);
-      brw_builder bldn = devinfo->ver >= 20 ? bld16 : bld.exec_all().group(8, 0);
+         if (!nir_src_is_const(nsrc)) {
+            src[i] = get_nir_src(ntb, nsrc, 0);
+            continue;
+         }
 
-      bldn.DPAS(dest,
-                src0,
-                retype(get_nir_src(ntb, instr->src[2], 0), src_type),
-                retype(get_nir_src(ntb, instr->src[1], 0), src_type),
+         /* A single constant value can be used to fill the entire
+          * cooperative matrix.  In this case get_nir_src() would give a
+          * uniform value (with stride 0), but DPAS can't use regioning,
+          * it needs the full data available in the register.
+          *
+          * So when a source is a constant, allocate the space necessary
+          * and fill it with the constant value, except for a zero Src0:
+          *
+          *    When Src0 is specified as null, it is treated as an
+          *    immediate value of +0.
+          *
+          * (ACM PRM, Vol 2a, "Dot Product Accumulate Systolic")
+          */
+         const unsigned num_components = nir_src_num_components(nsrc);
+         const unsigned bit_size = nir_src_bit_size(nsrc);
+         const nir_const_value *nval = nir_src_as_const_value(nsrc);
+
+         assert(bit_size <= 32);
+         for (unsigned j = 1; j < num_components; j++)
+            assert(nval[0].u32 == nval[j].u32);
+         uint32_t val = nval[0].u32;
+
+         if (i == 0 && val == 0) {
+            src[i] = brw_null_reg();
+
+         } else {
+            unsigned size = bit_size * num_components;
+            unsigned count = size / 32;
+            assert(size % 32 == 0);
+
+            src[i] = bld.vgrf(BRW_TYPE_UD, count);
+            for (unsigned j = 0; j < count; j++)
+               bld.exec_all().MOV(offset(src[i], bld, j), brw_imm_ud(val));
+         }
+      }
+
+      const unsigned dpas_exec_size = devinfo->ver >= 20 ? 16 : 8;
+      brw_builder bldn = bld.exec_all().group(dpas_exec_size, 0);
+
+      /* DPAS uses a different source order: Accumulator, B, A. */
+      bldn.DPAS(retype(dest, dest_type),
+                retype(src[0], dest_type),
+                retype(src[2], src_type),
+                retype(src[1], src_type),
                 sdepth,
                 rcount)
          ->saturate = nir_intrinsic_saturate(instr);