nir: make lowering use new ffma opcodes

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41165>
This commit is contained in:
Karol Herbst 2026-04-23 03:44:32 +02:00 committed by Marge Bot
parent 109d93dd98
commit e1aaaf4ed0
10 changed files with 36 additions and 59 deletions

View file

@ -41,10 +41,10 @@ nir_cross3(nir_builder *b, nir_def *x, nir_def *y)
unsigned yzx[3] = { 1, 2, 0 };
unsigned zxy[3] = { 2, 0, 1 };
return nir_ffma_old(b, nir_swizzle(b, x, yzx, 3),
nir_swizzle(b, y, zxy, 3),
nir_fneg(b, nir_fmul(b, nir_swizzle(b, x, zxy, 3),
nir_swizzle(b, y, yzx, 3))));
return nir_ffma_weak(b, nir_swizzle(b, x, yzx, 3),
nir_swizzle(b, y, zxy, 3),
nir_fneg(b, nir_fmul(b, nir_swizzle(b, x, zxy, 3),
nir_swizzle(b, y, yzx, 3))));
}
nir_def *
@ -285,7 +285,7 @@ nir_atan(nir_builder *b, nir_def *y_over_x)
nir_imm_floatN_t(b, -M_PI_2, bit_size));
/* multiply through by x while fixing up the range reduction */
nir_def *tmp = nir_ffma_old(b, nir_fabs(b, u), res, bias);
nir_def *tmp = nir_ffma_weak(b, nir_fabs(b, u), res, bias);
/* sign fixup */
return nir_copysign(b, tmp, y_over_x);

View file

@ -125,20 +125,6 @@ lower_reduction(nir_alu_instr *alu, nir_op chan_op, nir_op merge_op,
return last;
}
/* Look up the shader's ffma lowering option for the given bit size.
 *
 * Returns shader->options->lower_ffma16, lower_ffma32 or lower_ffma64 for a
 * bit_size of 16, 32 or 64 respectively.  Any other bit size falls out of the
 * switch and reaches UNREACHABLE.
 */
static inline bool
will_lower_ffma(nir_shader *shader, unsigned bit_size)
{
switch (bit_size) {
case 16:
return shader->options->lower_ffma16;
case 32:
return shader->options->lower_ffma32;
case 64:
return shader->options->lower_ffma64;
}
/* Only 16, 32 and 64 are handled by the switch above. */
UNREACHABLE("bad bit size");
}
static nir_def *
lower_bfdot_to_bfdot2_bfadd(nir_builder *b, nir_alu_instr *alu)
{
@ -183,12 +169,12 @@ lower_fdot(nir_alu_instr *alu, nir_builder *builder, bool is_bfloat16)
/* If we don't want to lower ffma, create several ffma directly instead of
* emitting fmul+fadd and fusing them later, because fusing is not possible
* for exact fdot instructions.
*/
if (!is_bfloat16 && will_lower_ffma(builder->shader, alu->def.bit_size))
if (!is_bfloat16 && nir_prefers_fmad(builder->shader, alu->def.bit_size))
return lower_reduction(alu, nir_op_fmul, nir_op_fadd, builder, reverse_order);
unsigned num_components = nir_op_infos[alu->op].input_sizes[0];
const nir_op fma_op = is_bfloat16 ? nir_op_bffma : nir_op_ffma_old;
const nir_op fma_op = is_bfloat16 ? nir_op_bffma : nir_op_ffma_weak;
const nir_op mul_op = is_bfloat16 ? nir_op_bfmul : nir_op_fmul;
nir_def *prev = NULL;
@ -315,7 +301,7 @@ lower_alu_instr_width(nir_builder *b, nir_instr *instr, void *_data)
/* Only use reverse order for imprecise fdph, see explanation in lower_fdot. */
bool reverse_order = !(b->fp_math_ctrl & nir_fp_exact);
if (will_lower_ffma(b->shader, alu->def.bit_size)) {
if (nir_prefers_fmad(b->shader, alu->def.bit_size)) {
nir_def *sum[4];
for (unsigned i = 0; i < 3; i++) {
int dest = reverse_order ? 3 - i : i;
@ -328,12 +314,12 @@ lower_alu_instr_width(nir_builder *b, nir_instr *instr, void *_data)
} else if (reverse_order) {
nir_def *sum = nir_channel(b, src1_vec, 3);
for (int i = 2; i >= 0; i--)
sum = nir_ffma_old(b, nir_channel(b, src0_vec, i), nir_channel(b, src1_vec, i), sum);
sum = nir_ffma_weak(b, nir_channel(b, src0_vec, i), nir_channel(b, src1_vec, i), sum);
return sum;
} else {
nir_def *sum = nir_fmul(b, nir_channel(b, src0_vec, 0), nir_channel(b, src1_vec, 0));
sum = nir_ffma_old(b, nir_channel(b, src0_vec, 1), nir_channel(b, src1_vec, 1), sum);
sum = nir_ffma_old(b, nir_channel(b, src0_vec, 2), nir_channel(b, src1_vec, 2), sum);
sum = nir_ffma_weak(b, nir_channel(b, src0_vec, 1), nir_channel(b, src1_vec, 1), sum);
sum = nir_ffma_weak(b, nir_channel(b, src0_vec, 2), nir_channel(b, src1_vec, 2), sum);
return nir_fadd(b, sum, nir_channel(b, src1_vec, 3));
}
}

View file

@ -52,8 +52,8 @@ replace_with_strict_ffma(struct nir_builder *bld, struct u_vector *dead_flrp,
nir_def *const c = nir_ssa_for_alu_src(bld, alu, 2);
nir_def *const neg_a = nir_fneg(bld, a);
nir_def *const inner_ffma = nir_ffma_old(bld, neg_a, c, a);
nir_def *const outer_ffma = nir_ffma_old(bld, b, c, inner_ffma);
nir_def *const inner_ffma = nir_ffma_weak(bld, neg_a, c, a);
nir_def *const outer_ffma = nir_ffma_weak(bld, b, c, inner_ffma);
nir_def_rewrite_uses(&alu->def, outer_ffma);
@ -79,7 +79,7 @@ replace_with_single_ffma(struct nir_builder *bld, struct u_vector *dead_flrp,
nir_def *const one_minus_c =
nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c);
nir_def *const b_times_c = nir_fmul(bld, b, c);
nir_def *const final_ffma = nir_ffma_old(bld, a, one_minus_c, b_times_c);
nir_def *const final_ffma = nir_ffma_weak(bld, a, one_minus_c, b_times_c);
nir_def_rewrite_uses(&alu->def, final_ffma);
@ -331,17 +331,8 @@ convert_flrp_instruction(nir_builder *bld,
nir_alu_instr *alu,
bool always_precise)
{
bool have_ffma = false;
unsigned bit_size = alu->def.bit_size;
if (bit_size == 16)
have_ffma = !bld->shader->options->lower_ffma16;
else if (bit_size == 32)
have_ffma = !bld->shader->options->lower_ffma32;
else if (bit_size == 64)
have_ffma = !bld->shader->options->lower_ffma64;
else
UNREACHABLE("invalid bit_size");
bool have_ffma = !nir_prefers_fmad(bld->shader, bit_size);
bld->cursor = nir_before_instr(&alu->instr);
bld->fp_math_ctrl = alu->fp_math_ctrl;

View file

@ -106,12 +106,12 @@ nir_lower_interpolation_instr(nir_builder *b, nir_instr *instr, void *cb_data)
nir_def *bary = intr->src[0].ssa;
nir_def *val;
val = nir_ffma_old(b, nir_channel(b, bary, 1),
nir_channel(b, iid, 1),
nir_channel(b, iid, 0));
val = nir_ffma_old(b, nir_channel(b, bary, 0),
nir_channel(b, iid, 2),
val);
val = nir_ffma_weak(b, nir_channel(b, bary, 1),
nir_channel(b, iid, 1),
nir_channel(b, iid, 0));
val = nir_ffma_weak(b, nir_channel(b, bary, 0),
nir_channel(b, iid, 2),
val);
comps[i] = val;
}

View file

@ -75,11 +75,11 @@ lower_load_pointcoord(lower_pntc_ytransform_state *state,
nir_def *pntc = &intr->def;
nir_def *transform = get_pntc_transform(state);
nir_def *flipped_y = nir_ffma_old(b, nir_channel(b, pntc, y_swizzle),
/* Flip the sign of y if we're flipping. */
nir_channel(b, transform, 0),
/* The offset is 1 if we're flipping, 0 otherwise. */
nir_channel(b, transform, 1));
nir_def *flipped_y = nir_ffma_weak(b, nir_channel(b, pntc, y_swizzle),
/* Flip the sign of y if we're flipping. */
nir_channel(b, transform, 0),
/* The offset is 1 if we're flipping, 0 otherwise. */
nir_channel(b, transform, 1));
/* Reassemble the vector. */
pntc = nir_vector_insert_imm(b, pntc, flipped_y, y_swizzle);

View file

@ -409,7 +409,7 @@ convert_yuv_to_rgb(nir_builder *b, nir_tex_instr *tex,
}
nir_def *result =
nir_ffma_old(b, y, m0, nir_ffma_old(b, u, m1, nir_ffma_old(b, v, m2, offset)));
nir_ffma_weak(b, y, m0, nir_ffma_weak(b, u, m1, nir_ffma_weak(b, v, m2, offset)));
nir_def_rewrite_uses(&tex->def, result);
}

View file

@ -106,8 +106,8 @@ emit_wpos_adjustment(lower_wpos_ytransform_state *state,
*/
unsigned base = invert ? 0 : 2;
/* wpos.y = wpos.y * trans.x/z + trans.y/w */
wpos[1] = nir_ffma_old(b, wpos[1], nir_channel(b, wpostrans, base),
nir_channel(b, wpostrans, base + 1));
wpos[1] = nir_ffma_weak(b, wpos[1], nir_channel(b, wpostrans, base),
nir_channel(b, wpostrans, base + 1));
}
nir_def *new_wpos = nir_vec(b, &wpos[c], intr->num_components);
@ -258,8 +258,8 @@ lower_load_sample_pos(lower_wpos_ytransform_state *state,
nir_def *scale = nir_channel(b, wpostrans, 0);
nir_def *neg_scale = nir_channel(b, wpostrans, 2);
/* Either y or 1-y for scale equal to 1 or -1 respectively. */
nir_def *flipped_y = nir_ffma_old(b, nir_channel(b, pos, 1), scale,
nir_fmax(b, neg_scale, nir_imm_float(b, 0.0)));
nir_def *flipped_y = nir_ffma_weak(b, nir_channel(b, pos, 1), scale,
nir_fmax(b, neg_scale, nir_imm_float(b, 0.0)));
nir_def *flipped_pos = nir_vector_insert_imm(b, pos, flipped_y, 1);
nir_def_rewrite_uses_after(&intr->def, flipped_pos);

View file

@ -101,7 +101,7 @@ denorm_ftz_64 = 'nir_is_denorm_flush_to_zero(info->float_controls_execution_mode
def lowered_sincos(c):
x = ('fsub', ('fmul', 2.0, ('ffract', ('fadd', ('fmul', 0.5 / pi, a), c))), 1.0)
x = ('fmul', ('fsub', x, ('fmul', x, ('fabs', x))), 4.0)
return ('ffma_old', ('ffma_old', x, ('fabs', x), ('fneg', x)), 0.225, x)
return ('ffma_weak', ('ffma_weak', x, ('fabs', x), ('fneg', x)), 0.225, x)
def intBitsToFloat(i):
return struct.unpack('!f', struct.pack('!I', i))[0]

View file

@ -3920,9 +3920,9 @@ try_move_postdominator(struct linkage_info *linkage,
defs[i] = nir_fmul(b, new_tes_loads[i],
nir_channel(b, tesscoord, remap[i]));
} else {
defs[i] = nir_ffma_old(b, new_tes_loads[i],
nir_channel(b, tesscoord, remap[i]),
defs[i - 1]);
defs[i] = nir_ffma_weak(b, new_tes_loads[i],
nir_channel(b, tesscoord, remap[i]),
defs[i - 1]);
}
}
new_input = defs[2];

View file

@ -430,7 +430,7 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
nir_fmul(nb, eta, nir_a_minus_bc(nb, one, n_dot_i, n_dot_i)));
nir_def *result =
nir_a_minus_bc(nb, nir_fmul(nb, eta, I),
nir_ffma_old(nb, eta, n_dot_i, nir_fsqrt(nb, k)),
nir_ffma_weak(nb, eta, n_dot_i, nir_fsqrt(nb, k)),
N);
/* XXX: bcsel, or if statement? */
dest->def = nir_bcsel(nb, nir_flt(nb, k, zero), zero, result);