From 575affaf487fc23f3d3b98c0f714d1325c8a6fb1 Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Wed, 28 Jan 2026 15:35:49 +0100 Subject: [PATCH] nir/search: gather union of all fp_math_ctrl Reviewed-by: Alyssa Rosenzweig Part-of: --- src/compiler/nir/nir.h | 9 ++++----- src/compiler/nir/nir_algebraic.py | 32 ++++++++++++++++++++++++++----- src/compiler/nir/nir_search.c | 31 ++++++++++++------------------ src/compiler/nir/nir_search.h | 18 +++-------------- 4 files changed, 46 insertions(+), 44 deletions(-) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 155d75ad83c..7feffc247a2 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -1490,14 +1490,13 @@ nir_op_is_selection(nir_op op) { return (nir_op_infos[op].algebraic_properties & NIR_OP_IS_SELECTION) != 0; } + +#define NIR_FP_MATH_CONTROL_BIT_COUNT 4 /** * Floating point fast math control. * * All new bits must restrict optimizations when they are set, not when they * are missing. This means a bitwise OR always produces a no less restrictive set. - * - * See also nir_alu_instr::exact, which should (and hopefully will be) moved - * to this enum in the future. */ typedef enum { /** @@ -1538,7 +1537,7 @@ typedef enum { nir_fp_preserve_nan, nir_fp_fast_math = 0, - nir_fp_no_fast_math = BITFIELD_MASK(4), + nir_fp_no_fast_math = BITFIELD_MASK(NIR_FP_MATH_CONTROL_BIT_COUNT), } nir_fp_math_control; /***/ @@ -1569,7 +1568,7 @@ typedef struct nir_alu_instr { * that have no float_controls2 equivalent (rounding mode and denorm handling) * remain in the execution mode only. */ - uint32_t fp_math_ctrl : 4; + uint32_t fp_math_ctrl : NIR_FP_MATH_CONTROL_BIT_COUNT; /** Sources * diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index a89b51c1117..ecb7a119615 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -225,13 +225,9 @@ class Value(object): ${val.cond_index}, ${val.swizzle()}, % elif isinstance(val, Expression): - ${'true' if val.inexact else 'false'}, + ${val.fp_math_ctrl_exclude()}, ${'true' if val.exact else 'false'}, ${'true' if val.ignore_exact else 'false'}, - ${'true' if val.nsz else 'false'}, - ${'true' if val.nnan else 'false'}, - ${'true' if val.ninf else 'false'}, - ${'true' if val.contract else 'false'}, ${'true' if len(val.sources) > 1 and isinstance(val.sources[1], Constant) else 'false'}, ${val.swizzle}, ${val.c_opcode()}, @@ -429,6 +425,7 @@ class Expression(Value): self.nnan = cond.pop('nnan', False) self.ninf = cond.pop('ninf', False) self.contract = cond.pop('contract', False) + # Single component index of the swizzle of the output of this # expression, or -1 if no swizzle (all components) self.swizzle = - \ @@ -510,6 +507,31 @@ class Expression(Value): srcs = "".join(src.render(cache) for src in self.sources) return srcs + super(Expression, self).render(cache) + def fp_math_ctrl_exclude(self): + exclude = set() + if self.inexact: + exclude.add("nir_fp_exact") + exclude.add("nir_fp_preserve_signed_zero") + exclude.add("nir_fp_preserve_inf") + exclude.add("nir_fp_preserve_nan") + + if self.contract: + exclude.add("nir_fp_exact") + + if self.nsz: + exclude.add("nir_fp_preserve_signed_zero") + + if self.ninf: + exclude.add("nir_fp_preserve_inf") + + if self.nnan: + exclude.add("nir_fp_preserve_nan") + + if not exclude: + return "nir_fp_fast_math" + + return ' | '.join(sorted(list(exclude))) + class BitSizeValidator(object): """A class for validating bit sizes of expressions. diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index 6d2ae292cdb..92f8163435a 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -32,7 +32,7 @@ struct match_state { bool inexact_match; - bool has_exact_alu; + unsigned fp_math_ctrl; uint8_t comm_op_direction; unsigned variables_seen; @@ -376,15 +376,6 @@ match_expression(const nir_algebraic_table *table, const nir_search_expression * if (expr->cond_index != -1 && !table->expression_cond[expr->cond_index](instr)) return false; - if (expr->nsz && nir_alu_instr_is_signed_zero_preserve(instr)) - return false; - - if (expr->nnan && nir_alu_instr_is_nan_preserve(instr)) - return false; - - if (expr->ninf && nir_alu_instr_is_inf_preserve(instr)) - return false; - if (!nir_op_matches_search_op(instr->op, expr->opcode)) return false; @@ -392,9 +383,14 @@ match_expression(const nir_algebraic_table *table, const nir_search_expression * instr->def.bit_size != expr->value.bit_size) return false; - state->inexact_match |= expr->inexact || expr->contract; - state->has_exact_alu |= nir_alu_instr_is_exact(instr) && !expr->ignore_exact; - if (state->inexact_match && state->has_exact_alu) + unsigned fp_math_ctrl = instr->fp_math_ctrl & ~(expr->ignore_exact ? nir_fp_exact : 0); + + if (expr->fp_math_ctrl_exclude & fp_math_ctrl) + return false; + + state->inexact_match |= expr->fp_math_ctrl_exclude & nir_fp_exact; + state->fp_math_ctrl |= fp_math_ctrl; + if (state->inexact_match && (state->fp_math_ctrl & nir_fp_exact)) return false; assert(nir_op_infos[instr->op].num_inputs > 0); @@ -486,7 +482,7 @@ construct_value(nir_builder *build, * replacement should be exact. */ alu->fp_math_ctrl = nir_instr_as_alu(instr)->fp_math_ctrl; - if (state->has_exact_alu || expr->exact) + if ((state->fp_math_ctrl & nir_fp_exact) || expr->exact) alu->fp_math_ctrl |= nir_fp_exact; for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) { @@ -618,7 +614,7 @@ dump_value(const nir_algebraic_table *table, const nir_search_value *val) case nir_search_value_expression: { const nir_search_expression *expr = nir_search_value_as_expression(val); fprintf(stderr, "("); - if (expr->inexact) + if (expr->fp_math_ctrl_exclude & nir_fp_exact) fprintf(stderr, "~"); switch (expr->opcode) { #define CASE(n) \ @@ -711,8 +707,8 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, swizzle[i] = i; struct match_state state; + state.fp_math_ctrl = nir_fp_fast_math; state.inexact_match = false; - state.has_exact_alu = false; state.state = search_state; state.pass_op_table = table->pass_op_table; state.table = table; @@ -886,15 +882,12 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr, nir_alu_instr *alu = nir_instr_as_alu(instr); - const bool ignore_inexact = nir_alu_instr_is_signed_zero_inf_nan_preserve(alu); - int xform_idx = *util_dynarray_element(states, uint16_t, alu->def.index); for (const struct transform *xform = &table->transforms[table->transform_offsets[xform_idx]]; xform->condition_offset != ~0; xform++) { if (condition_flags[xform->condition_offset] && - !(table->values[xform->search].expression.inexact && ignore_inexact) && nir_replace_instr(build, alu, state, states, table, &table->values[xform->search].expression, &table->values[xform->replace].value, worklist, dead_instrs)) { diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h index 548c1f1da72..eaab496f585 100644 --- a/src/compiler/nir/nir_search.h +++ b/src/compiler/nir/nir_search.h @@ -128,10 +128,10 @@ typedef struct { nir_search_value value; /* When set on a search expression, the expression will only match an SSA - * value that does *not* have the exact bit set. If unset, the exact bit - * on the SSA value is ignored. + * value that does *not* have these float control bits set. If unset, + * the bits on the instruction are ignored for matching. */ - bool inexact : 1; + unsigned fp_math_ctrl_exclude : NIR_FP_MATH_CONTROL_BIT_COUNT; /** In a replacement, requests that the instruction be marked exact. */ bool exact : 1; @@ -139,18 +139,6 @@ typedef struct { /** Don't make the replacement exact if the search expression is exact. */ bool ignore_exact : 1; - /** Replacement does not preserve signed of zero. */ - bool nsz : 1; - - /** Replacement does not preserve NaN. */ - bool nnan : 1; - - /** Replacement does not preserve infinities. */ - bool ninf : 1; - - /** Replacement contracts an expression */ - bool contract : 1; - /** Whether the second source is a nir_search_value_constant */ bool src1_is_const : 1;