mesa/src/compiler/nir/nir_opt_reassociate.c
Alyssa Rosenzweig 0c49738211 nir/opt_reassociate: fix exactness bug
For an inexact-associative operation (fadd or fmul), can_reassociate ensures the
root of the chain is inexact before allowing reassociation. However, build_chain
then only checks that opcodes match as it walks the chain, even though we do
union exactness across the chain. Despite that effort, it is still incorrect to
reassociate

   %3 = fadd! %0, %1
   %4 = fadd %3, %2

into, for example,

   %3 = fadd! %0, %2
   %4 = fadd! %3, %1

Closes: #14418
Fixes: e0b0f7e73c ("nir: add ALU reassocation pass")
Signed-off-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41162>
2026-04-28 21:14:56 +00:00

/*
 * Copyright 2025 Valve Corporation
 * SPDX-License-Identifier: MIT
 */
#include <stdint.h>
#include "util/hash_table.h"
#include "util/list.h"
#include "util/ralloc.h"
#include "util/u_dynarray.h"
#include "util/u_math.h"
#include "util/u_qsort.h"
#include "nir.h"
#include "nir_builder.h"

/* NIR pass to reassociate scalar binary arithmetic.
 *
 * Before running this pass, isub/fsub should be lowered to iadd/fadd.
 * iadd3/imin3/etc should be split into binary operations. If possible, fma
 * should be split to fmul/fadd. This maximizes the number of binary operation
 * chains the pass can reassociate.
 *
 * After running this pass, other passes should be run to get the benefit:
 * constant folding, CSE, algebraic, nir_opt_preamble, copy prop, DCE, etc.
 *
 * How does the algorithm work?
 *
 * We first identify "chains". A chain is a list of (not necessarily unique)
 * sources, where a fixed binary operation is repeatedly applied to reduce the
 * chain. Each intermediate operation must only be used by its parent. In
 * other words, a chain is a linearized expression tree.
 *
 * If we have the NIR:
 *
 *    %5 = iadd %0, %1
 *    %6 = iadd %2, %3
 *    %7 = iadd %5, %6
 *    %8 = iadd %4, %7
 *
 * Then (%0, %1, %2, %3, %4) is a length-5 chain rooted at the last iadd.
 *
 * The sources in each chain are reordered, then we rewrite the program to use
 * our selected order. The chosen order affects how effective other
 * optimizations are. We therefore use two major heuristics.
 *
 * The first heuristic is "sort by rank". Rank is traditionally defined as how
 * "deep" a definition is in the control flow graph. Constants get rank 0,
 * definitions involving 1 level of control flow rank 1, and so on. By
 * operating on low rank sources first, we improve our chances of hoisting
 * low rank operations. Sort-by-rank therefore promotes constant folding,
 * preamble/scalar ALU usage, and loop-invariant code motion.
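 *
 * As an illustrative example using this pass's concrete ranks (constants are
 * rank 0, convergent values rank 1, divergent values rank 2; see rank()
 * below): given a chain (%v, %u, %c) where %c is a constant, %u is
 * convergent, and %v is divergent, sorting by rank yields (%c, %u, %v),
 * which is rebuilt as
 *
 *    ((%c + %u) + %v)
 *
 * The inner add of a constant with a convergent value is itself convergent,
 * so it can be hoisted to a preamble or executed on a scalar ALU.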
 *
 * The second heuristic is the "global CSE" heuristic. Pairs of sources might
 * appear in multiple chains. By reordering to perform these common operations
 * first, we are able to CSE inner calculations across chains. This is
 * especially effective for graphics shaders, which often contain code like:
 *
 *    scale * normalize(v)
 *
 * ...scalarizing to
 *
 *    inv_magnitude = rsq(dot(v, v))
 *    scale * (v.x * inv_magnitude)
 *    scale * (v.y * inv_magnitude)
 *    scale * (v.z * inv_magnitude)
 *
 * This scalar code contains three fmul chains:
 *
 *    (scale, v.x, inv_magnitude)
 *    (scale, v.y, inv_magnitude)
 *    (scale, v.z, inv_magnitude)
 *
 * We count the number of appearances of each pair globally:
 *
 *    3  (scale, inv_magnitude)
 *    1  (scale, v.x), (scale, v.y), (scale, v.z)
 *
 * For each chain, the (scale, inv_magnitude) pair has the highest frequency so
 * is performed first, exposing the CSE opportunity:
 *
 *    inv_magnitude = rsq(dot(v, v))
 *    v.x * (scale * inv_magnitude)
 *    v.y * (scale * inv_magnitude)
 *    v.z * (scale * inv_magnitude)
 *
 * References:
 *
 * Rank heuristic: https://web.eecs.umich.edu/~mahlke/courses/583f22/lectures/Nov14/group19_paper.pdf
 * CSE heuristic: https://reviews.llvm.org/D40049
 * LLVM: https://llvm.org/doxygen/Reassociate_8cpp_source.html
 * GCC: https://github.com/gcc-mirror/gcc/tree/master/gcc/tree-ssa-reassoc.cc
 */

#define MAX_CHAIN_LENGTH   16
#define PASS_FLAG_INTERIOR (1)

struct pair_key {
   /* Def index of each source */
   uint32_t index[2];

   /* Component of each source */
   uint8_t component[2];

   /* Operation applied to the pair. Each operation gets a separate abstract
    * pair map, concretely implemented by including the opcode in the key.
    *
    * nir_op, but uint16_t because of MSVC.
    */
   uint16_t op;
};
static_assert(sizeof(struct pair_key) == 12, "packed");

DERIVE_HASH_TABLE(pair_key);

static struct pair_key
get_pair_key(nir_op op, nir_scalar a, nir_scalar b)
{
   /* Normalize pairs for better results, exploiting op's commutativity. */
   if ((a.def->index > b.def->index) ||
       ((a.def->index == b.def->index) && (a.comp > b.comp))) {
      SWAP(a, b);
   }
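
   /* After normalization, commuted operand orders produce identical keys:
    * e.g. (a = %5.x, b = %2.y) and (a = %2.y, b = %5.x) both key as
    * {%2.y, %5.x}, so a commutative pair is counted once regardless of
    * source order.
    */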
   return (struct pair_key){
      .index = {a.def->index, b.def->index},
      .component = {a.comp, b.comp},
      .op = op,
   };
}

/*
 * We record the frequency of pairs in a hash table. As a small optimization, we
 * record the frequency (which is always non-zero) in the `data` field directly,
 * without an extra indirection.
 */
static void
increment_pair_freq(struct hash_table *ht, struct pair_key key)
{
   uint32_t hash = pair_key_hash(&key);
   struct hash_entry *ent = _mesa_hash_table_search_pre_hashed(ht, hash, &key);

   if (ent) {
      ent->data = (void *)(((uintptr_t)ent->data) + 1);
   } else {
      struct pair_key *clone = ralloc_memdup(ht, &key, sizeof(key));
      _mesa_hash_table_insert_pre_hashed(ht, hash, clone, (void *)(uintptr_t)1);
   }
}

static unsigned
lookup_pair_freq(struct hash_table *ht, struct pair_key key)
{
   return (uintptr_t)(_mesa_hash_table_search(ht, &key)->data);
}

static int
rank(nir_scalar s)
{
   /* Constants are rank 0. This promotes constant folding. */
   if (nir_scalar_is_const(s))
      return 0;

   /* Convergent expressions are rank 1, promoting preambles and scalar ALU */
   if (!s.def->divergent)
      return 1;

   /* Everything else is rank 2. TODO: Promote loop-invariant code motion. */
   return 2;
}

struct chain {
   nir_alu_instr *root;
   unsigned length;
   nir_scalar srcs[MAX_CHAIN_LENGTH];
   bool do_global_cse;
   unsigned fp_math_ctrl;
};

UNUSED static void
print_chain(struct chain *c)
{
   for (unsigned i = 0; i < c->length; ++i) {
      printf("%s%u.%c", i ? ", " : "", c->srcs[i].def->index,
             "xyzw"[c->srcs[i].comp]);
   }

   printf("\n");
}

static bool
can_reassociate(nir_alu_instr *alu)
{
   /* By design, we only handle scalar math. */
   if (alu->def.num_components != 1)
      return false;

   /* Check for the relevant algebraic properties. get_pair_key requires
    * commutativity. NIR does not currently have non-commutative associative
    * ALU operations, although that could change.
    */
   nir_op_algebraic_property props = nir_op_infos[alu->op].algebraic_properties;

   return (props & NIR_OP_IS_2SRC_COMMUTATIVE) &&
          ((props & NIR_OP_IS_ASSOCIATIVE) ||
           (!nir_alu_instr_no_reassoc(alu) &&
            (props & NIR_OP_IS_INEXACT_ASSOCIATIVE)));
}

/*
 * Recursive depth-first-search rooted at a given instruction to build a chain
 * of sources. Effectively, this linearizes expression trees. We cap the search
 * depth with careful accounting to ensure we do not exceed MAX_CHAIN_LENGTH.
 */
static void
build_chain(struct chain *c, nir_scalar def, unsigned reserved_count)
{
   nir_alu_instr *alu = nir_def_as_alu(def.def);

   /* Conservative fast math handling: take the union of all float controls
    * along the chain. Float controls may be safely added but not removed.
    */
   c->fp_math_ctrl |= alu->fp_math_ctrl;

   for (unsigned i = 0; i < 2; ++i) {
      nir_scalar src = nir_scalar_chase_alu_src(def, i);
      unsigned remaining = 1 - i;
      unsigned reserved_plus_remaining = reserved_count + remaining;
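
      /* Capacity accounting for the recursion below: descending into src adds
       * at least 2 sources (its two operands, if both turn out to be leaves),
       * while reserved_plus_remaining counts the sibling sources still owed at
       * this and every enclosing level, each needing at least one slot. Deeper
       * growth is re-checked by this same test at each level of recursion.
       */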
      if (nir_scalar_is_alu(src) && nir_scalar_alu_op(src) == alu->op &&
          can_reassociate(nir_def_as_alu(src.def)) &&
          list_is_singular(&src.def->uses) &&
          c->length + reserved_plus_remaining + 2 <= MAX_CHAIN_LENGTH) {

         /* Any interior nodes cannot be the root */
         nir_def_instr(src.def)->pass_flags = PASS_FLAG_INTERIOR;

         /* Recurse, reserving space for the next sources */
         build_chain(c, src, reserved_count + remaining);
      } else {
         assert(c->length < MAX_CHAIN_LENGTH);
         c->srcs[c->length++] = src;
      }
   }
}

/* Iterate all O(N^2) pairs. Since we don't care about order or self-pairs, we
 * start j at (i + 1) to improve runtime.
 */
#define foreach_pair(chain, i, j)                          \
   for (unsigned i = 0; i < (chain)->length; ++i)          \
      for (unsigned j = i + 1; j < (chain)->length; ++j)

static void
record_pairs(struct chain *c, struct hash_table *pair_freq)
{
   struct pair_key keys[MAX_CHAIN_LENGTH * MAX_CHAIN_LENGTH];
   unsigned key_count = 0;

   foreach_pair(c, i, j) {
      struct pair_key key = get_pair_key(c->root->op, c->srcs[i], c->srcs[j]);
      bool unique = true;

      /* Deduplicate keys within a chain to avoid bias */
      for (unsigned k = 0; k < key_count; ++k) {
         if (pair_key_equal(&keys[k], &key)) {
            unique = false;
            break;
         }
      }

      /* Increment for unique keys */
      if (unique) {
         increment_pair_freq(pair_freq, key);
         keys[key_count++] = key;
      }
   }
}

/*
 * Search for chains. To do so efficiently, we walk backwards. NIR's source
 * order is compatible with dominance. That guarantees we see roots before
 * interior instructions/leaves. When searching at each potential root, we mark
 * interior nodes as we go, so we know not to consider them for roots. This
 * ensures we do not duplicate chains and keeps `find_chains` O(instructions).
 */
static void
find_chains(nir_function_impl *impl, struct hash_table *pair_freq,
            struct util_dynarray *chains)
{
   nir_foreach_block_reverse(block, impl) {
      nir_foreach_instr_reverse(instr, block) {
         if (instr->type != nir_instr_type_alu ||
             instr->pass_flags == PASS_FLAG_INTERIOR)
            continue;

         nir_alu_instr *alu = nir_instr_as_alu(instr);
         if (!can_reassociate(alu))
            continue;

         /* Find the chain rooted at `alu` */
         struct chain c = {.root = alu, .length = 0};
         build_chain(&c, nir_get_scalar(&alu->def, 0), 0);

         /* Record pairs even if we won't reassociate this chain, so we get
          * better CSE behaviour globally with other chains.
          */
         if (pair_freq && c.length <= 8)
            record_pairs(&c, pair_freq);

         /* We need at least 3 sources to reassociate anything */
         if (c.length < 3)
            continue;

         /* Analyze the chain to feed our heuristic */
         unsigned lowest_rank = UINT32_MAX, nr_lowest = 0;
         unsigned highest_rank = 0, nr_highest = 0;
         bool local = true;

         for (unsigned i = 0; i < c.length; ++i) {
            lowest_rank = MIN2(rank(c.srcs[i]), lowest_rank);
            highest_rank = MAX2(rank(c.srcs[i]), highest_rank);
            local &= nir_def_block(c.srcs[i].def) == block;
         }

         for (unsigned i = 0; i < c.length; ++i) {
            nr_lowest += (rank(c.srcs[i]) == lowest_rank);
            nr_highest += (rank(c.srcs[i]) == highest_rank);
         }

         /* If we don't have the pair_freq table, the caller doesn't want to
          * use the global CSE heuristic at all.
          */
         c.do_global_cse = pair_freq != NULL;

         /* The global CSE heuristic is quadratic-time in the length of the
          * chain, because it needs to consider all pairs. We limit that
          * heuristic to small chains, which keeps the per-chain work constant
          * (and hence the pass linear overall). Past a point, increasing
          * chain lengths has diminishing returns anyway.
          *
          * Secondarily, this serves to control register pressure. Both
          * reassociating chains and CSE itself tend to increase pressure. This
          * increase is particularly pronounced for chains spanning a large
          * part of the control flow graph. Therefore, we allow longer chains
          * when they are local (all instructions in a single basic block)
          * than when they cross blocks. This trades off instruction count
          * and register pressure, and probably needs to be tuned.
          */
         c.do_global_cse &= c.length <= (local ? 8 : 3);

         /* The heuristic targeting global CSE can interfere with preamble
          * forming, where sort-by-rank excels. For chains where all sources
          * have the same rank except 1, we disable the CSE heuristic and
          * instead sort-by-rank. This is itself a heuristic.
          *
          * As a concrete example, consider the code:
          *
          *    out1 = input1 + uniform1 + uniform2
          *    out2 = input1 + uniform1 + uniform3
          *
          * The global CSE heuristic will associate this code as:
          *
          *    out1 = (input1 + uniform1) + uniform2
          *    out2 = (input1 + uniform1) + uniform3
          *
          * This lets us delete 1 addition by CSE'ing the first add. However,
          * it prevents us from hoisting anything to the preamble, because the
          * result of that CSE'd addition is not uniform.
          *
          * Sort-by-rank instead associates the code:
          *
          *    out1 = input1 + (uniform1 + uniform2)
          *    out2 = input1 + (uniform1 + uniform3)
          *
          * Both uniform-uniform adds get hoisted to the preamble. For the
          * main shader, this is a net reduction of 1 add.
          *
          * For hardware with scalar ALUs but no preambles: the first version
          * costs 3 VALU, the second version costs 2 VALU + 2 SALU. Since SALU
          * is usually underused, that may be a win.
          *
          * For hardware that has neither, this heuristic only affects
          * constants. Enabling constant folding here is a strict win.
          */
         c.do_global_cse &= nr_lowest != (c.length - 1);

         /* If all the ranks are the same, sort-by-rank is pointless */
         bool sort_by_rank = nr_lowest != c.length;

         /* If all ranks are maximal except one, sort-by-rank is unlikely to
          * help much. This is a chain like "scalar + vector + vector", which
          * is 2 vector adds no matter where we put the scalar. Reassociating
          * such a chain is likely to increase register pressure without
          * improving instruction count, so bail. This is a heuristic
          * tradeoff.
          */
         sort_by_rank &= nr_highest != (c.length - 1);

         /* Reassociate the chain if one of our heuristics can improve it */
         if (sort_by_rank || c.do_global_cse) {
            util_dynarray_append(chains, struct chain, c);
         }
      }
   }
}

struct pair {
   unsigned i, j;
};

/*
 * Find the most frequent pair in a chain. Tie break with the pair with the
 * lowest max rank of the two operands. This is the meat of the CSE heuristic.
 */
static struct pair
find_best_pair_in_chain(struct chain *c, void *pair_freq)
{
   struct pair best = {0};
   unsigned best_max_rank = 0, best_freq = 0;

   foreach_pair(c, i, j) {
      struct pair_key key = get_pair_key(c->root->op, c->srcs[i], c->srcs[j]);
      unsigned freq = lookup_pair_freq(pair_freq, key);
      unsigned max_rank = MAX2(rank(c->srcs[i]), rank(c->srcs[j]));

      if (freq > best_freq || (freq == best_freq && max_rank < best_max_rank)) {
         best = (struct pair){i, j};
         best_max_rank = max_rank;
         best_freq = freq;
      }
   }
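
   /* A pair is only worth pinning if it appears at least twice globally;
    * otherwise there is nothing to CSE with. Signal "no pair" as {0, 0},
    * which the caller detects via i == j (real pairs always have i < j).
    */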
   return best_freq > 1 ? best : (struct pair){0};
}

/* Compare ranks. Tie break to ensure the sort-by-rank sort is stable */
static int
cmp_rank(const void *a_, const void *b_)
{
   const nir_scalar *a = a_, *b = b_;
   int ra = rank(*a), rb = rank(*b);

   if (ra != rb)
      return ra - rb;
   else
      return a->def->index - b->def->index;
}

static bool
reassociate_chain(struct chain *c, void *pair_freq)
{
   nir_builder b = nir_builder_at(nir_before_instr(&c->root->instr));
   b.fp_math_ctrl = c->fp_math_ctrl;

   /* Pick a new order using sort-by-rank and possibly the CSE heuristics */
   unsigned pinned = 0;

   if (c->do_global_cse) {
      struct pair best_pair = find_best_pair_in_chain(c, pair_freq);

      if (best_pair.i != best_pair.j) {
         /* Pin the best pair at the front. The rest is sorted by rank. */
         SWAP(c->srcs[0], c->srcs[best_pair.i]);
         SWAP(c->srcs[1], c->srcs[best_pair.j]);
         pinned = 2;
      }
   }

   qsort(c->srcs + pinned, c->length - pinned, sizeof(c->srcs[0]), cmp_rank);

   /* Reassociate according to the new order */
   nir_def *new_root = nir_mov_scalar(&b, c->srcs[0]);
   nir_def *last_src = NULL;

   for (unsigned i = 1; i < c->length; ++i) {
      nir_def *src = nir_mov_scalar(&b, c->srcs[i]);

      /* If a source is duplicated in a chain, sort-by-rank groups the
       * duplicates. Associate [x, y, y] as (x + (y + y)) to fuse FMA.
       */
      if (i < c->length - 1 && nir_scalar_equal(c->srcs[i], c->srcs[i + 1])) {
         src = nir_build_alu2(&b, c->root->op, src, src);
         ++i;
      }

      if (i < c->length - 1)
         new_root = nir_build_alu2(&b, c->root->op, new_root, src);
      else
         last_src = src;
   }

   /* It is essential that the root itself is rewritten in place, rather than
    * adding a new instruction and rewriting uses. The root may be used as a
    * source in other chains, and we do all the analysis upfront, so we would
    * get dangling references to the pre-rewrite root.
    *
    * For interior nodes, it doesn't matter, since nothing references them
    * outside the chain by definition. The old instructions will be DCE'd.
    */
   nir_alu_src_rewrite_scalar(&c->root->src[0], nir_get_scalar(last_src, 0));
   nir_alu_src_rewrite_scalar(&c->root->src[1], nir_get_scalar(new_root, 0));

   /* Set flags conservatively, matching the rest of the chain */
   c->root->no_signed_wrap = c->root->no_unsigned_wrap = false;
   c->root->fp_math_ctrl = c->fp_math_ctrl;

   return true;
}

bool
nir_opt_reassociate(nir_shader *nir, nir_reassociate_options opts)
{
   bool cse_heuristic = opts & nir_reassociate_cse_heuristic;
   struct hash_table *pair_freq =
      cse_heuristic ? pair_key_table_create(NULL) : NULL;
   struct util_dynarray chains;
   bool progress = false;

   /* Clear pass flags. All instructions are possible roots, a priori.
    * Interior nodes are indicated with non-zero pass flags, set as we go.
    */
   chains = UTIL_DYNARRAY_INIT;
   nir_shader_clear_pass_flags(nir);

   /* We use nir_def indices, which are function-local, so the algorithm runs
    * on one function at a time.
    */
   nir_foreach_function_impl(impl, nir) {
      if (opts & nir_reassociate_scalar_math)
         nir_metadata_require(impl, nir_metadata_divergence);

      nir_index_ssa_defs(impl);

      bool impl_progress = false;
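
      /* pair_freq is NULL when the CSE heuristic is disabled; both this clear
       * and the final ralloc_free accept NULL.
       */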
      _mesa_hash_table_clear(pair_freq, NULL);
      util_dynarray_clear(&chains);

      /* Step 1: find all chains in the function */
      find_chains(impl, pair_freq, &chains);

      /* Step 2: reassociate all chains */
      util_dynarray_foreach(&chains, struct chain, chain) {
         impl_progress |= reassociate_chain(chain, pair_freq);
      }

      nir_progress(impl_progress, impl, nir_metadata_control_flow);
      progress |= impl_progress;
   }

   ralloc_free(pair_freq);
   util_dynarray_fini(&chains);
   return progress;
}

/* Helper loop for doing reassociation.
 *
 * Each set of flags needs two passes: the first pass may reassociate
 * constants together, which constant folding then collapses, and the shorter
 * chains that result need to be rebalanced by a second pass. If both
 * heuristics are requested, we run the full CSE path first, then finalize
 * with just scalar math once the CSE work has been done.
 */
bool
nir_opt_reassociate_loop(nir_shader *nir, nir_reassociate_options in_opts)
{
   bool any_progress = false;

   for (unsigned i = 0; i < 2; ++i) {
      nir_reassociate_options opts = in_opts;

      if (i >= 1) {
         if (in_opts == (nir_reassociate_cse_heuristic |
                         nir_reassociate_scalar_math))
            opts = nir_reassociate_scalar_math;
         else
            break;
      }

      /* For this set of opt flags, reassociate, then clean up with constant
       * folding, then re-reassociate to rebalance.
       */
      for (unsigned j = 0; j < 2; j++) {
         bool progress = false;
         NIR_PASS(progress, nir, nir_opt_reassociate, opts);

         if (!progress)
            break;

         any_progress = true;

         do {
            progress = false;
            NIR_PASS(progress, nir, nir_opt_algebraic);
            NIR_PASS(progress, nir, nir_opt_constant_folding);
            NIR_PASS(progress, nir, nir_opt_copy_prop);
            NIR_PASS(progress, nir, nir_opt_cse);
            NIR_PASS(progress, nir, nir_opt_dce);
            any_progress |= progress;
         } while (progress);
      }
   }

   return any_progress;
}
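
/* Illustrative usage sketch (an assumption about driver code, not part of
 * this pass): a backend wanting both heuristics might call, late in its
 * optimization loop and after the lowering described in the header comment:
 *
 *    progress |= nir_opt_reassociate_loop(
 *       shader, nir_reassociate_cse_heuristic | nir_reassociate_scalar_math);
 *
 * The loop already interleaves algebraic/constant-folding/CSE/DCE cleanup,
 * so no extra cleanup is required immediately afterwards.
 */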