diff --git a/src/gallium/drivers/i915/i915_fpc_nir.c b/src/gallium/drivers/i915/i915_fpc_nir.c index 99f1b664142..d1835800413 100644 --- a/src/gallium/drivers/i915/i915_fpc_nir.c +++ b/src/gallium/drivers/i915/i915_fpc_nir.c @@ -622,11 +622,12 @@ emit_alu(struct nir_to_i915 *c, nir_alu_instr *alu) } } uint32_t hw_op = 0; - bool can_fuse = (vec_arg >= 0 && nargs == 2); + bool can_fuse = (vec_arg >= 0); if (can_fuse) { switch (consumer->op) { case nir_op_fmul: hw_op = A0_MUL; break; case nir_op_fadd: hw_op = A0_ADD; break; + case nir_op_ffma: hw_op = A0_MAD; break; case nir_op_fmin: case nir_op_imin: case nir_op_umin: hw_op = A0_MIN; break; case nir_op_fmax: case nir_op_imax: case nir_op_umax: @@ -634,13 +635,22 @@ emit_alu(struct nir_to_i915 *c, nir_alu_instr *alu) default: can_fuse = false; break; } } - /* check the non-vec source is a single register */ + /* check the non-vec sources are single registers */ + uint32_t other_srcs[3] = { 0, 0, 0 }; + if (can_fuse) { + for (unsigned a = 0; a < nargs; a++) { + if ((int)a == vec_arg) + continue; + nir_def *od = consumer->src[a].src.ssa; + if (od->index >= c->ureg_map_size || + c->ureg_map[od->index] == UREG_BAD) { + can_fuse = false; + break; + } + other_srcs[a] = alu_src_ureg(c, &consumer->src[a]); + } + } if (can_fuse) { - int other_arg = 1 - vec_arg; - nir_def *other_def = consumer->src[other_arg].src.ssa; - if (other_def->index < c->ureg_map_size && - c->ureg_map[other_def->index] != UREG_BAD) { - uint32_t other = alu_src_ureg(c, &consumer->src[other_arg]); nir_def *cdef = &consumer->def; uint32_t cdest = dest; uint32_t cmask = def_mask(cdef); @@ -669,21 +679,20 @@ emit_alu(struct nir_to_i915 *c, nir_alu_instr *alu) uint32_t fused_src = negate( swizzle(base, ch[0], ch[1], ch[2], ch[3]), ng[0], ng[1], ng[2], ng[3]); - if (vec_arg == 0) - i915_emit_arith(p, hw_op, cdest, - group_mask & cmask, 0, - fused_src, other, 0); - else - i915_emit_arith(p, hw_op, cdest, - group_mask & cmask, 0, - other, fused_src, 0); + uint32_t s[3]; + for (unsigned a = 0; a < nargs; a++) + s[a] = ((int)a == vec_arg) ? fused_src + : other_srcs[a]; + i915_emit_arith(p, hw_op, cdest, + group_mask & cmask, 0, + s[0], nargs > 1 ? s[1] : 0, + nargs > 2 ? s[2] : 0); emitted[i] = true; } set_ureg(c, cdef, cdest); c->def_csr[cdef->index] = p->csr - 3; break; - } } } }