From e466b8735bf05a7466908e40e2b2c73350c5123a Mon Sep 17 00:00:00 2001 From: Alyssa Rosenzweig Date: Tue, 15 Jul 2025 14:21:45 -0400 Subject: [PATCH] nir: introduce "inexact associative" property nothing currently uses the associative flag, but that will change soon. we need to stop incorrectly marking fmul/fadd/etc as associative, because they're not, but they almost are. distinguish these properties so we can correctly handle floating point rules without any opcode-based special casing. Signed-off-by: Alyssa Rosenzweig Reviewed-by: Mel Henning Part-of: --- src/compiler/nir/nir.h | 7 +++++++ src/compiler/nir/nir_opcodes.py | 9 +++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index a746f71ae63..5771930c7e4 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -1435,6 +1435,13 @@ typedef enum { * comparison. */ NIR_OP_IS_SELECTION = (1 << 2), + + /** + * Operation is associative mathematically (as real numbers), but not + * associative with floating-point math. This can be treated as associative + * iff the operation's exact bit is not set. + */ + NIR_OP_IS_INEXACT_ASSOCIATIVE = (1 << 3), } nir_op_algebraic_property; /* vec16 is the widest ALU op in NIR, making the max number of input of ALU diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index a0a06439f60..34db1b2bd4e 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -149,6 +149,7 @@ def type_base_type(type_): # sources. 
_2src_commutative = "2src_commutative " associative = "associative " +inexact_associative = "inexact_associative " selection = "selection " # global dictionary of opcodes @@ -610,7 +611,7 @@ def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr, binop_reduce("b32" + name[1:], output_size, tbool32, src_type, prereduce_expr, reduce_expr, final_expr, description) -binop("fadd", tfloat, _2src_commutative + associative,""" +binop("fadd", tfloat, _2src_commutative + inexact_associative,""" if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { if (bit_size == 64) dst = _mesa_double_add_rtz(src0, src1); @@ -652,7 +653,7 @@ binop_convert("uabs_isub", tuint, tint, "", """ """) binop("uabs_usub", tuint, "", "(src1 > src0) ? (src1 - src0) : (src0 - src1)") -binop("fmul", tfloat, _2src_commutative + associative, """ +binop("fmul", tfloat, _2src_commutative + inexact_associative, """ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { if (bit_size == 64) dst = _mesa_double_mul_rtz(src0, src1); @@ -663,7 +664,7 @@ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { } """) -binop("fmulz", tfloat32, _2src_commutative + associative, """ +binop("fmulz", tfloat32, _2src_commutative + inexact_associative, """ if (src0 == 0.0 || src1 == 0.0) dst = 0.0; else if (nir_is_rounding_mode_rtz(execution_mode, 32)) @@ -1755,7 +1756,7 @@ opcode("udot_2x16_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], unop_numeric_convert("bf2f", tfloat32, tuint16, "_mesa_bfloat16_bits_to_float(src0)") unop_numeric_convert("f2bf", tuint16, tfloat32, "_mesa_float_to_bfloat16_bits_rte(src0)") -binop("bfmul", tuint16, _2src_commutative + associative, """ +binop("bfmul", tuint16, _2src_commutative + inexact_associative, """ const float a = _mesa_bfloat16_bits_to_float(src0); const float b = _mesa_bfloat16_bits_to_float(src1); dst = _mesa_float_to_bfloat16_bits_rte(a * b);