nir/search: gather union of all fp_math_ctrl

Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39616>
This commit is contained in:
Georg Lehmann 2026-01-28 15:35:49 +01:00 committed by Marge Bot
parent 3275be503c
commit 575affaf48
4 changed files with 46 additions and 44 deletions

View file

@ -1490,14 +1490,13 @@ nir_op_is_selection(nir_op op)
{
return (nir_op_infos[op].algebraic_properties & NIR_OP_IS_SELECTION) != 0;
}
#define NIR_FP_MATH_CONTROL_BIT_COUNT 4
/**
* Floating point fast math control.
*
* All new bits must restrict optimizations when they are set, not when they
* are missing. This means the bitwise OR of two sets is never less restrictive
* than either operand.
*
* See also nir_alu_instr::exact, which should (and hopefully will be) moved
* to this enum in the future.
*/
typedef enum {
/**
@ -1538,7 +1537,7 @@ typedef enum {
nir_fp_preserve_nan,
nir_fp_fast_math = 0,
nir_fp_no_fast_math = BITFIELD_MASK(4),
nir_fp_no_fast_math = BITFIELD_MASK(NIR_FP_MATH_CONTROL_BIT_COUNT),
} nir_fp_math_control;
/***/
@ -1569,7 +1568,7 @@ typedef struct nir_alu_instr {
* that have no float_controls2 equivalent (rounding mode and denorm handling)
* remain in the execution mode only.
*/
uint32_t fp_math_ctrl : 4;
uint32_t fp_math_ctrl : NIR_FP_MATH_CONTROL_BIT_COUNT;
/** Sources
*

View file

@ -225,13 +225,9 @@ class Value(object):
${val.cond_index},
${val.swizzle()},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
${val.fp_math_ctrl_exclude()},
${'true' if val.exact else 'false'},
${'true' if val.ignore_exact else 'false'},
${'true' if val.nsz else 'false'},
${'true' if val.nnan else 'false'},
${'true' if val.ninf else 'false'},
${'true' if val.contract else 'false'},
${'true' if len(val.sources) > 1 and isinstance(val.sources[1], Constant) else 'false'},
${val.swizzle},
${val.c_opcode()},
@ -429,6 +425,7 @@ class Expression(Value):
self.nnan = cond.pop('nnan', False)
self.ninf = cond.pop('ninf', False)
self.contract = cond.pop('contract', False)
# Single component index of the swizzle of the output of this
# expression, or -1 if no swizzle (all components)
self.swizzle = - \
@ -510,6 +507,31 @@ class Expression(Value):
srcs = "".join(src.render(cache) for src in self.sources)
return srcs + super(Expression, self).render(cache)
def fp_math_ctrl_exclude(self):
    """Return the C expression for the fp_math_ctrl bits this pattern excludes.

    The result is emitted into the generated search table. Each search
    flag on this expression contributes the nir_fp_math_control bit(s)
    listed below:

      inexact  -> nir_fp_exact, nir_fp_preserve_signed_zero,
                  nir_fp_preserve_inf, nir_fp_preserve_nan
      contract -> nir_fp_exact
      nsz      -> nir_fp_preserve_signed_zero
      ninf     -> nir_fp_preserve_inf
      nnan     -> nir_fp_preserve_nan

    Returns a string of enum names joined with ' | ', or
    "nir_fp_fast_math" when no flag is set.
    """
    exclude = set()

    if self.inexact:
        # An inexact pattern excludes every restriction bit at once.
        exclude.update((
            "nir_fp_exact",
            "nir_fp_preserve_signed_zero",
            "nir_fp_preserve_inf",
            "nir_fp_preserve_nan",
        ))
    if self.contract:
        exclude.add("nir_fp_exact")
    if self.nsz:
        exclude.add("nir_fp_preserve_signed_zero")
    if self.ninf:
        exclude.add("nir_fp_preserve_inf")
    if self.nnan:
        exclude.add("nir_fp_preserve_nan")

    if not exclude:
        return "nir_fp_fast_math"

    # Set iteration order is not stable across runs; sort so the generated
    # source is deterministic. sorted() accepts any iterable directly.
    return ' | '.join(sorted(exclude))
class BitSizeValidator(object):
"""A class for validating bit sizes of expressions.

View file

@ -32,7 +32,7 @@
struct match_state {
bool inexact_match;
bool has_exact_alu;
unsigned fp_math_ctrl;
uint8_t comm_op_direction;
unsigned variables_seen;
@ -376,15 +376,6 @@ match_expression(const nir_algebraic_table *table, const nir_search_expression *
if (expr->cond_index != -1 && !table->expression_cond[expr->cond_index](instr))
return false;
if (expr->nsz && nir_alu_instr_is_signed_zero_preserve(instr))
return false;
if (expr->nnan && nir_alu_instr_is_nan_preserve(instr))
return false;
if (expr->ninf && nir_alu_instr_is_inf_preserve(instr))
return false;
if (!nir_op_matches_search_op(instr->op, expr->opcode))
return false;
@ -392,9 +383,14 @@ match_expression(const nir_algebraic_table *table, const nir_search_expression *
instr->def.bit_size != expr->value.bit_size)
return false;
state->inexact_match |= expr->inexact || expr->contract;
state->has_exact_alu |= nir_alu_instr_is_exact(instr) && !expr->ignore_exact;
if (state->inexact_match && state->has_exact_alu)
unsigned fp_math_ctrl = instr->fp_math_ctrl & ~(expr->ignore_exact ? nir_fp_exact : 0);
if (expr->fp_math_ctrl_exclude & fp_math_ctrl)
return false;
state->inexact_match |= expr->fp_math_ctrl_exclude & nir_fp_exact;
state->fp_math_ctrl |= fp_math_ctrl;
if (state->inexact_match && (state->fp_math_ctrl & nir_fp_exact))
return false;
assert(nir_op_infos[instr->op].num_inputs > 0);
@ -486,7 +482,7 @@ construct_value(nir_builder *build,
* replacement should be exact.
*/
alu->fp_math_ctrl = nir_instr_as_alu(instr)->fp_math_ctrl;
if (state->has_exact_alu || expr->exact)
if ((state->fp_math_ctrl & nir_fp_exact) || expr->exact)
alu->fp_math_ctrl |= nir_fp_exact;
for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
@ -618,7 +614,7 @@ dump_value(const nir_algebraic_table *table, const nir_search_value *val)
case nir_search_value_expression: {
const nir_search_expression *expr = nir_search_value_as_expression(val);
fprintf(stderr, "(");
if (expr->inexact)
if (expr->fp_math_ctrl_exclude & nir_fp_exact)
fprintf(stderr, "~");
switch (expr->opcode) {
#define CASE(n) \
@ -711,8 +707,8 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
swizzle[i] = i;
struct match_state state;
state.fp_math_ctrl = nir_fp_fast_math;
state.inexact_match = false;
state.has_exact_alu = false;
state.state = search_state;
state.pass_op_table = table->pass_op_table;
state.table = table;
@ -886,15 +882,12 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr,
nir_alu_instr *alu = nir_instr_as_alu(instr);
const bool ignore_inexact = nir_alu_instr_is_signed_zero_inf_nan_preserve(alu);
int xform_idx = *util_dynarray_element(states, uint16_t,
alu->def.index);
for (const struct transform *xform = &table->transforms[table->transform_offsets[xform_idx]];
xform->condition_offset != ~0;
xform++) {
if (condition_flags[xform->condition_offset] &&
!(table->values[xform->search].expression.inexact && ignore_inexact) &&
nir_replace_instr(build, alu, state, states, table,
&table->values[xform->search].expression,
&table->values[xform->replace].value, worklist, dead_instrs)) {

View file

@ -128,10 +128,10 @@ typedef struct {
nir_search_value value;
/* When set on a search expression, the expression will only match an SSA
* value that does *not* have the exact bit set. If unset, the exact bit
* on the SSA value is ignored.
* value that does *not* have these float control bits set. If unset,
* the bits on the instruction are ignored for matching.
*/
bool inexact : 1;
unsigned fp_math_ctrl_exclude : NIR_FP_MATH_CONTROL_BIT_COUNT;
/** In a replacement, requests that the instruction be marked exact. */
bool exact : 1;
@ -139,18 +139,6 @@ typedef struct {
/** Don't make the replacement exact if the search expression is exact. */
bool ignore_exact : 1;
/** Replacement does not preserve sign of zero. */
bool nsz : 1;
/** Replacement does not preserve NaN. */
bool nnan : 1;
/** Replacement does not preserve infinities. */
bool ninf : 1;
/** Replacement contracts an expression */
bool contract : 1;
/** Whether the second source is a nir_search_value_constant */
bool src1_is_const : 1;