nir: introduce "inexact associative" property

nothing currently uses the associative flag, but they will change soon. we need
to stop incorrectly marking fmul/fadd/etc as associative, because they're not,
but they almost are. distinguish these properties so we can correctly
handle floating point rules without any opcode-based special casing.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Mel Henning <mhenning@darkrefraction.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36257>
This commit is contained in:
Alyssa Rosenzweig 2025-07-15 14:21:45 -04:00 committed by Marge Bot
parent 421d0e0953
commit e466b8735b
2 changed files with 12 additions and 4 deletions

View file

@ -1435,6 +1435,13 @@ typedef enum {
* comparison.
*/
NIR_OP_IS_SELECTION = (1 << 2),
/**
* Operation is associative mathematically (as real numbers), but not
* associative with floating-point math. This can be treated as associative
* iff the operation's exact bit is not set.
*/
NIR_OP_IS_INEXACT_ASSOCIATIVE = (1 << 3),
} nir_op_algebraic_property;
/* vec16 is the widest ALU op in NIR, making the max number of input of ALU

View file

@ -149,6 +149,7 @@ def type_base_type(type_):
# sources.
_2src_commutative = "2src_commutative "
associative = "associative "
inexact_associative = "inexact_associative "
selection = "selection "
# global dictionary of opcodes
@ -610,7 +611,7 @@ def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr,
binop_reduce("b32" + name[1:], output_size, tbool32, src_type,
prereduce_expr, reduce_expr, final_expr, description)
binop("fadd", tfloat, _2src_commutative + associative,"""
binop("fadd", tfloat, _2src_commutative + inexact_associative,"""
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
if (bit_size == 64)
dst = _mesa_double_add_rtz(src0, src1);
@ -652,7 +653,7 @@ binop_convert("uabs_isub", tuint, tint, "", """
""")
binop("uabs_usub", tuint, "", "(src1 > src0) ? (src1 - src0) : (src0 - src1)")
binop("fmul", tfloat, _2src_commutative + associative, """
binop("fmul", tfloat, _2src_commutative + inexact_associative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
if (bit_size == 64)
dst = _mesa_double_mul_rtz(src0, src1);
@ -663,7 +664,7 @@ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
}
""")
binop("fmulz", tfloat32, _2src_commutative + associative, """
binop("fmulz", tfloat32, _2src_commutative + inexact_associative, """
if (src0 == 0.0 || src1 == 0.0)
dst = 0.0;
else if (nir_is_rounding_mode_rtz(execution_mode, 32))
@ -1755,7 +1756,7 @@ opcode("udot_2x16_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
unop_numeric_convert("bf2f", tfloat32, tuint16, "_mesa_bfloat16_bits_to_float(src0)")
unop_numeric_convert("f2bf", tuint16, tfloat32, "_mesa_float_to_bfloat16_bits_rte(src0)")
binop("bfmul", tuint16, _2src_commutative + associative, """
binop("bfmul", tuint16, _2src_commutative + inexact_associative, """
const float a = _mesa_bfloat16_bits_to_float(src0);
const float b = _mesa_bfloat16_bits_to_float(src1);
dst = _mesa_float_to_bfloat16_bits_rte(a * b);