brw: Properly handle cooperative matrices created with constants
Some checks are pending
macOS-CI / macOS-CI (dri) (push) Waiting to run
macOS-CI / macOS-CI (xlib) (push) Waiting to run

Expand constant sources so they cover the full region read by DPAS, and
use the NULL register as the accumulator when the constant accumulator is zero.

Reviewed-by: Sushma Venkatesh Reddy <sushma.venkatesh.reddy@intel.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34373>
This commit is contained in:
Caio Oliveira 2025-04-03 11:05:28 -07:00
parent 16e3e0d93b
commit bf9ad36f2d

View file

@ -4697,16 +4697,59 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
const brw_reg_type src_type =
brw_type_for_nir_type(devinfo, nir_intrinsic_src_type(instr));
dest = retype(dest, dest_type);
brw_reg src0 = retype(get_nir_src(ntb, instr->src[0], 0), dest_type);
brw_reg src[3] = {};
for (unsigned i = 0; i < ARRAY_SIZE(src); i++) {
nir_src nsrc = instr->src[i];
brw_builder bld16 = bld.exec_all().group(16, 0);
brw_builder bldn = devinfo->ver >= 20 ? bld16 : bld.exec_all().group(8, 0);
if (!nir_src_is_const(nsrc)) {
src[i] = get_nir_src(ntb, nsrc, 0);
continue;
}
bldn.DPAS(dest,
src0,
retype(get_nir_src(ntb, instr->src[2], 0), src_type),
retype(get_nir_src(ntb, instr->src[1], 0), src_type),
/* A single constant value can be used to fill the entire
* cooperative matrix. In this case get_nir_src() would give a
* uniform value (with stride 0), but DPAS can't use regioning,
* it needs the full data available in the register.
*
* So when a source is a constant, allocate the space necessary
* and fill it with the constant value. The one exception is a
* constant zero accumulator (Src0), which can use the null
* register instead, because
*
* When Src0 is specified as null, it is treated as an
* immediate value of +0.
*
* as documented in ACM PRM, Vol 2a, "Dot Product Accumulate Systolic".
*/
const unsigned num_components = nir_src_num_components(nsrc);
const unsigned bit_size = nir_src_bit_size(nsrc);
const nir_const_value *nval = nir_src_as_const_value(instr->src[0]);
assert(bit_size <= 32);
for (unsigned j = 1; j < num_components; j++)
assert(nval[0].u32 == nval[j].u32);
uint32_t val = nval[0].u32;
if (i == 0 && val == 0) {
src[i] = brw_null_reg();
} else {
unsigned size = bit_size * num_components;
unsigned count = size / 32;
assert(size % 32 == 0);
src[i] = bld.vgrf(BRW_TYPE_UD, count);
for (unsigned j = 0; j < count; j++)
bld.exec_all().MOV(offset(src[i], bld, j), brw_imm_ud(val));
}
}
const unsigned dpas_exec_size = devinfo->ver >= 20 ? 16 : 8;
brw_builder bldn = bld.exec_all().group(dpas_exec_size, 0);
/* DPAS uses a different source order: Accumulator, B, A. */
bldn.DPAS(retype(dest, dest_type),
retype(src[0], dest_type),
retype(src[2], src_type),
retype(src[1], src_type),
sdepth,
rcount)
->saturate = nir_intrinsic_saturate(instr);