nir: add new float multiply-add opcodes

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41165>
This commit is contained in:
Karol Herbst 2026-04-19 23:13:32 +02:00 committed by Marge Bot
parent 0251be32cb
commit 32e91a7467
4 changed files with 212 additions and 0 deletions

View file

@ -1712,6 +1712,45 @@ nir_alu_instr_channel_used(const nir_alu_instr *instr, unsigned src,
bool
nir_alu_instr_is_comparison(const nir_alu_instr *instr);
/**
 * Tells whether an ALU instruction is one of the regular float multiply-add
 * opcodes: ffma, ffma_weak, fmad or ffma_old.
 *
 * A NULL instr is accepted and reported as "not a multiply-add".
 */
static inline bool
nir_alu_instr_is_mul_add(const nir_alu_instr *instr)
{
   return instr &&
          (instr->op == nir_op_ffma ||
           instr->op == nir_op_ffma_weak ||
           instr->op == nir_op_fmad ||
           instr->op == nir_op_ffma_old);
}
/**
 * Tells whether an ALU instruction is one of the "z" float multiply-add
 * opcodes: ffmaz, fmadz or ffmaz_old.
 *
 * A NULL instr is accepted and reported as "not a multiply-add".
 */
static inline bool
nir_alu_instr_is_mul_add_z(const nir_alu_instr *instr)
{
   if (!instr)
      return false;

   if (instr->op == nir_op_ffmaz)
      return true;
   if (instr->op == nir_op_fmadz)
      return true;

   return instr->op == nir_op_ffmaz_old;
}
/**
 * Tells whether an ALU instruction is any float multiply-add opcode, i.e.
 * it satisfies nir_alu_instr_is_mul_add() or nir_alu_instr_is_mul_add_z().
 */
static inline bool
nir_alu_instr_is_any_mul_add(const nir_alu_instr *alu)
{
   if (nir_alu_instr_is_mul_add(alu))
      return true;

   return nir_alu_instr_is_mul_add_z(alu);
}
bool nir_const_value_negative_equal(nir_const_value c1, nir_const_value c2,
nir_alu_type full_type);
@ -7282,6 +7321,37 @@ nir_is_io_compact(nir_shader *nir, bool is_output, unsigned location)
(nir->info.stage != MESA_SHADER_MESH && location == VARYING_SLOT_TESS_LEVEL_INNER));
}
/**
 * Looks up the float multiply-add support flags that the shader's compiler
 * options declare for the given bit size.
 *
 * Only 16, 32 and 64 are valid; any other bit_size reaches UNREACHABLE().
 */
static inline nir_float_muladd_support
nir_float_muladd_for_bitsize(const nir_shader *nir, unsigned bit_size)
{
   if (bit_size == 16)
      return nir->options->float_mul_add16;

   if (bit_size == 32)
      return nir->options->float_mul_add32;

   if (bit_size == 64)
      return nir->options->float_mul_add64;

   UNREACHABLE("unsupported bit_size");
   return (nir_float_muladd_support)0;
}
/**
 * Returns true when the compiler options for this bit size have the
 * nir_float_muladd_support_has_ffma flag set.
 */
static inline bool
nir_has_ffma(const nir_shader *nir, unsigned bit_size)
{
   const nir_float_muladd_support flags =
      nir_float_muladd_for_bitsize(nir, bit_size);

   if (flags & nir_float_muladd_support_has_ffma)
      return true;

   return false;
}
/**
 * Returns true when the split form should be used for this bit size: either
 * the options carry nir_float_muladd_support_prefers_split, or they do not
 * carry nir_float_muladd_support_has_ffma at all.
 */
static inline bool
nir_prefers_fmad(const nir_shader *nir, unsigned bit_size)
{
   const nir_float_muladd_support flags =
      nir_float_muladd_for_bitsize(nir, bit_size);

   const bool wants_split = (flags & nir_float_muladd_support_prefers_split) != 0;
   const bool lacks_ffma = (flags & nir_float_muladd_support_has_ffma) == 0;

   return wants_split || lacks_ffma;
}
#ifdef __cplusplus
} /* extern "C" */
#endif

View file

@ -100,6 +100,33 @@ get_float_source(nir_const_value value, unsigned execution_mode, unsigned bit_si
return nir_const_value_as_float(value, bit_size);
}
/**
* Properly rounds and handles denorms for intermediate results. Useful for
* fused opcodes that want to behave exactly like the unfused variants, e.g.
* fmad.
*/
static double
handle_intermediate_float_result(double value, unsigned execution_mode, unsigned bit_size)
{
nir_const_value const_val;
switch(bit_size) {
case 64:
const_val.f64 = value;
break;
case 32:
const_val.f32 = value;
break;
case 16:
if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
const_val.u16 = _mesa_float_to_float16_rtz(value);
} else {
const_val.u16 = _mesa_float_to_float16_rtne(value);
}
}
return get_float_source(const_val, execution_mode, bit_size);
}
/**
* Evaluate one component of packSnorm4x8.
*/

View file

@ -1132,6 +1132,66 @@ def triop_horiz(name, output_size, src1_size, src2_size, src3_size, const_expr,
[src1_size, src2_size, src3_size],
[tuint, tuint, tuint], False, "", const_expr, description)
# fmad: unfused multiply-add. The constant expression evaluates the product,
# rounds it as an intermediate result at the operation's bit size via
# handle_intermediate_float_result(), and only then adds src2. Both steps use
# the round-to-zero helpers when the execution mode requests RTZ.
triop("fmad", tfloat, _2src_commutative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_mul_rtz(src0, src1);
   else
      dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
} else {
   dst = src0 * src1;
}

dst = handle_intermediate_float_result(dst, execution_mode, bit_size);

if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_add_rtz(dst, src2);
   else
      dst = _mesa_double_to_float_rtz((double)dst + (double)src2);
} else {
   dst = dst + src2;
}
""", description = "Floating-point unfused multiply-add")
# ffma: fused multiply-add. The constant expression evaluates src0 * src1 +
# src2 as a single fused operation, using the round-to-zero helpers when the
# execution mode requests RTZ and fmaf()/fma() otherwise.
triop("ffma", tfloat, _2src_commutative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_fma_rtz(src0, src1, src2);
   else if (bit_size == 32)
      dst = _mesa_float_fma_rtz(src0, src1, src2);
   else
      dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2));
} else {
   if (bit_size == 32)
      dst = fmaf(src0, src1, src2);
   else
      dst = fma(src0, src1, src2);
}
""", description = "Floating-point fused multiply-add")
# ffma_weak: multiply-add that may end up fused or unfused. For constant
# folding it is evaluated with the same fused expression as ffma.
triop("ffma_weak", tfloat, _2src_commutative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_fma_rtz(src0, src1, src2);
   else if (bit_size == 32)
      dst = _mesa_float_fma_rtz(src0, src1, src2);
   else
      dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2));
} else {
   if (bit_size == 32)
      dst = fmaf(src0, src1, src2);
   else
      dst = fma(src0, src1, src2);
}
""", description = """
Floating-point multiply-add that can either be unfused or fused.
Precise ``ffma_weak`` are required to be either fused or unfused across all
shaders and shader stages, whereas imprecise ``ffma_weak`` doesn't have to
remain consistent, not even within the same shader.
This is like GLSL's ``fma``.
""")
triop("ffma_old", tfloat, _2src_commutative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
if (bit_size == 64)
@ -1148,6 +1208,44 @@ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
}
""")
# fmadz: unfused multiply-add with modified zero handling. When either
# multiplicand compares equal to zero the product is taken as +0.0; otherwise
# the product is rounded as an intermediate result before src2 is added.
#
# The final add must be an if/else: in RTZ mode the round-to-zero sum is the
# result, and falling through to the plain add would add src2 a second time.
triop("fmadz", tfloat32, _2src_commutative, """
if (src0 == 0.0 || src1 == 0.0) {
   dst = 0.0 + src2;
} else {
   if (nir_is_rounding_mode_rtz(execution_mode, 32))
      dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
   else
      dst = (src0 * src1);

   dst = handle_intermediate_float_result(dst, execution_mode, bit_size);

   if (nir_is_rounding_mode_rtz(execution_mode, 32))
      dst = _mesa_double_to_float_rtz((double)dst + (double)src2);
   else
      dst = dst + src2;
}
""", description = """
Floating-point unfused multiply-add with modified zero handling.
Unlike :nir:alu-op:`fmad`, anything (even infinity or NaN) multiplied by +/-0.0 is
+0.0. ``fmadz(0.0, inf, src2)`` and ``fmadz(0.0, nan, src2)`` must be
``+0.0 + src2``.
""")
# ffmaz: fused multiply-add with modified zero handling, declared with
# tfloat32. Constant folding: when either multiplicand compares equal to zero
# the product is taken as +0.0 and the result is 0.0 + src2; otherwise a fused
# multiply-add is evaluated, through _mesa_float_fma_rtz() when the 32-bit
# rounding mode is round-to-zero and through fmaf() otherwise.
triop("ffmaz", tfloat32, _2src_commutative, """
if (src0 == 0.0 || src1 == 0.0)
dst = 0.0 + src2;
else if (nir_is_rounding_mode_rtz(execution_mode, 32))
dst = _mesa_float_fma_rtz(src0, src1, src2);
else
dst = fmaf(src0, src1, src2);
""", description = """
Floating-point fused multiply-add with modified zero handling.
Unlike :nir:alu-op:`ffma`, anything (even infinity or NaN) multiplied by +/-0.0 is
+0.0. ``ffmaz(0.0, inf, src2)`` and ``ffmaz(0.0, nan, src2)`` must be
``+0.0 + src2``.
""")
triop("ffmaz_old", tfloat32, _2src_commutative, """
if (src0 == 0.0 || src1 == 0.0)
dst = 0.0 + src2;

View file

@ -259,6 +259,20 @@ typedef enum {
nir_frag_coord_use_pixel_coord = BITFIELD_BIT(2),
} nir_frag_coord_form;
/**
 * Bit flags describing how a backend handles float multiply-add for one
 * bit size. The values are single bits and may be OR'ed together.
 */
typedef enum {
   /** Tested by nir_has_ffma(). */
   nir_float_muladd_support_has_ffma = 1 << 0,

   /* NOTE(review): presumably the backend has a native fmad - confirm. */
   nir_float_muladd_support_has_fmad = 1 << 1,

   /** Strongly hints that fmad or fmul+fadd is preferred over ffma */
   nir_float_muladd_support_prefers_split = 1 << 2,

   /** ffma_weak won't be lowered */
   nir_float_muladd_support_keep_weak_ffma = 1 << 3,

   /* NOTE(review): presumably requests fusing fmul+fadd pairs - confirm. */
   nir_float_muladd_support_fuse = 1 << 4,
} nir_float_muladd_support;
/* NOTE(review): presumably provides the C++ bitwise operators for
 * nir_float_muladd_support so the flags can be combined without casts;
 * confirm against the macro definition.
 */
MESA_DEFINE_CPP_ENUM_BITFIELD_OPERATORS(nir_float_muladd_support)
typedef struct nir_shader_compiler_options {
bool lower_fdiv;
bool lower_ffma16;
@ -267,6 +281,9 @@ typedef struct nir_shader_compiler_options {
bool fuse_ffma16;
bool fuse_ffma32;
bool fuse_ffma64;
nir_float_muladd_support float_mul_add16;
nir_float_muladd_support float_mul_add32;
nir_float_muladd_support float_mul_add64;
bool lower_flrp16;
bool lower_flrp32;
/** Lowers flrp when it does not support doubles */