i915/corm: extend ALU consumer fusion to ffma and 3-input ops

Generalize the binary ALU consumer fusion to handle ffma (MAD) and
any number of inputs. When a vec's only consumer is an ALU op where
the vec occupies one source slot and all other sources are single
registers, emit the ALU op per register group with partial
writemasks.

shader-db (I915_FS=nir): 252/403 compiled, 3618 alu
shader-db (I915_FS=both): nir won 252 (26 identical, 1 tied, 219 better, 6 only),
  38 TGSI, 113 neither

Assisted-by: Claude
This commit is contained in:
Adam Jackson 2026-05-07 12:24:44 -04:00
parent 2a2ef36852
commit bfbba3f3b4

View file

@ -622,11 +622,12 @@ emit_alu(struct nir_to_i915 *c, nir_alu_instr *alu)
} }
} }
uint32_t hw_op = 0; uint32_t hw_op = 0;
bool can_fuse = (vec_arg >= 0 && nargs == 2); bool can_fuse = (vec_arg >= 0);
if (can_fuse) { if (can_fuse) {
switch (consumer->op) { switch (consumer->op) {
case nir_op_fmul: hw_op = A0_MUL; break; case nir_op_fmul: hw_op = A0_MUL; break;
case nir_op_fadd: hw_op = A0_ADD; break; case nir_op_fadd: hw_op = A0_ADD; break;
case nir_op_ffma: hw_op = A0_MAD; break;
case nir_op_fmin: case nir_op_imin: case nir_op_umin: case nir_op_fmin: case nir_op_imin: case nir_op_umin:
hw_op = A0_MIN; break; hw_op = A0_MIN; break;
case nir_op_fmax: case nir_op_imax: case nir_op_umax: case nir_op_fmax: case nir_op_imax: case nir_op_umax:
@ -634,13 +635,22 @@ emit_alu(struct nir_to_i915 *c, nir_alu_instr *alu)
default: can_fuse = false; break; default: can_fuse = false; break;
} }
} }
/* check the non-vec source is a single register */ /* check the non-vec sources are single registers */
uint32_t other_srcs[3] = { 0, 0, 0 };
if (can_fuse) {
for (unsigned a = 0; a < nargs; a++) {
if ((int)a == vec_arg)
continue;
nir_def *od = consumer->src[a].src.ssa;
if (od->index >= c->ureg_map_size ||
c->ureg_map[od->index] == UREG_BAD) {
can_fuse = false;
break;
}
other_srcs[a] = alu_src_ureg(c, &consumer->src[a]);
}
}
if (can_fuse) { if (can_fuse) {
int other_arg = 1 - vec_arg;
nir_def *other_def = consumer->src[other_arg].src.ssa;
if (other_def->index < c->ureg_map_size &&
c->ureg_map[other_def->index] != UREG_BAD) {
uint32_t other = alu_src_ureg(c, &consumer->src[other_arg]);
nir_def *cdef = &consumer->def; nir_def *cdef = &consumer->def;
uint32_t cdest = dest; uint32_t cdest = dest;
uint32_t cmask = def_mask(cdef); uint32_t cmask = def_mask(cdef);
@ -669,21 +679,20 @@ emit_alu(struct nir_to_i915 *c, nir_alu_instr *alu)
uint32_t fused_src = negate( uint32_t fused_src = negate(
swizzle(base, ch[0], ch[1], ch[2], ch[3]), swizzle(base, ch[0], ch[1], ch[2], ch[3]),
ng[0], ng[1], ng[2], ng[3]); ng[0], ng[1], ng[2], ng[3]);
if (vec_arg == 0) uint32_t s[3];
i915_emit_arith(p, hw_op, cdest, for (unsigned a = 0; a < nargs; a++)
group_mask & cmask, 0, s[a] = ((int)a == vec_arg) ? fused_src
fused_src, other, 0); : other_srcs[a];
else i915_emit_arith(p, hw_op, cdest,
i915_emit_arith(p, hw_op, cdest, group_mask & cmask, 0,
group_mask & cmask, 0, s[0], nargs > 1 ? s[1] : 0,
other, fused_src, 0); nargs > 2 ? s[2] : 0);
emitted[i] = true; emitted[i] = true;
} }
set_ureg(c, cdef, cdest); set_ureg(c, cdef, cdest);
c->def_csr[cdef->index] = p->csr - 3; c->def_csr[cdef->index] = p->csr - 3;
break; break;
}
} }
} }
} }