mesa/src/compiler/nir/nir_opt_uub.c
Konstantin Seurer de32f9275f treewide: add & use parent instr helpers
We add a bunch of new helpers to avoid the need to touch >parent_instr,
including the full set of:

* nir_def_is_*
* nir_def_as_*_or_null
* nir_def_as_* [assumes the right instr type]
* nir_src_is_*
* nir_src_as_*
* nir_scalar_is_*
* nir_scalar_as_*

Plus nir_def_instr() where there's no more suitable helper.

Also an existing helper is renamed to unify all the names, while we're
churning the tree:

* nir_src_as_alu_instr -> nir_src_as_alu

..and then we port the tree to use the helpers as much as possible, using
nir_def_instr() where that does not work.

Acked-by: Marek Olšák <maraeo@gmail.com>

---

To eliminate nir_def::parent_instr we need to churn the tree anyway, so I'm
taking this opportunity to clean up a lot of NIR patterns.

Co-authored-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Signed-off-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38313>
2025-11-12 21:22:13 +00:00

328 lines
9.3 KiB
C

/*
* Copyright © 2025 Igalia S.L.
* SPDX-License-Identifier: MIT
*/
#include "nir.h"
#include "nir_builder.h"
typedef struct {
const nir_opt_uub_options *options;
nir_shader *shader;
struct hash_table *range_ht;
} opt_uub_state;
static uint32_t
uub(opt_uub_state *state, nir_scalar s)
{
return nir_unsigned_upper_bound(state->shader, state->range_ht, s);
}
static void
get_srcs(nir_alu_instr *alu, nir_scalar *srcs)
{
assert(alu->def.num_components == 1);
nir_scalar def = nir_get_scalar(&alu->def, 0);
for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
srcs[i] = nir_scalar_chase_alu_src(def, i);
}
static bool
get_src_and_const(nir_alu_instr *alu, nir_scalar *src, nir_scalar *const_src,
unsigned *const_src_idx)
{
assert(nir_op_infos[alu->op].num_inputs == 2);
nir_scalar srcs[2];
get_srcs(alu, srcs);
for (unsigned i = 0; i < 2; i++) {
if (nir_scalar_is_const(srcs[i])) {
*const_src = srcs[i];
*src = srcs[1 - i];
if (const_src_idx)
*const_src_idx = i;
return true;
}
}
return false;
}
/* iand src, mask: if mask is constant with N least significant bits set and
* uub(src) < 2^N, the iand does nothing and can be removed.
*/
static bool
opt_uub_iand(nir_builder *b, nir_alu_instr *alu, opt_uub_state *state)
{
assert(alu->op == nir_op_iand);
nir_scalar src, mask;
if (!get_src_and_const(alu, &src, &mask, NULL))
return false;
unsigned first_0 = ffsll(~nir_scalar_as_uint(mask));
uint32_t low_mask = (1ull << (first_0 - 1)) - 1;
if (low_mask == 0)
return false;
if (uub(state, src) > low_mask)
return false;
b->cursor = nir_after_def(src.def);
nir_def_replace(&alu->def, nir_mov_scalar(b, src));
return true;
}
static nir_op
invert_cmp(nir_op op)
{
switch (op) {
case nir_op_ult:
return nir_op_uge;
case nir_op_uge:
return nir_op_ult;
case nir_op_ilt:
return nir_op_ige;
case nir_op_ige:
return nir_op_ilt;
default:
UNREACHABLE("unexpected cmp op");
}
}
/* ult src, const: if uub(src) < cmp -> true
* uge src, const: if uub(src) < cmp -> false
* ilt src, const: if uub(src) >= 0 && cmp < 0 -> false
* if uub(src) >= 0 && cmp >= 0 -> ult src, const
* ige src, const: if uub(src) >= 0 && cmp < 0 -> true
* if uub(src) >= 0 && cmp >= 0 -> uge src, const
*/
static bool
opt_uub_cmp(nir_builder *b, nir_alu_instr *alu, opt_uub_state *state)
{
assert(alu->op == nir_op_ult || alu->op == nir_op_uge ||
alu->op == nir_op_ilt || alu->op == nir_op_ige);
if (nir_src_bit_size(alu->src[0].src) > 32)
return false;
nir_scalar src, cmp;
unsigned const_src_idx;
if (!get_src_and_const(alu, &src, &cmp, &const_src_idx))
return false;
uint32_t src_uub = uub(state, src);
/* To make the code below more uniform, make sure the constant is always
* the RHS operand by inverting the opcode when it isn't.
*/
nir_op op = const_src_idx == 0 ? invert_cmp(alu->op) : alu->op;
if (op == nir_op_ilt || op == nir_op_ige) {
/* If src could be negative, there's nothing we can prove. */
if (util_sign_extend(src_uub, src.def->bit_size) < 0)
return false;
if (nir_scalar_as_int(cmp) < 0) {
/* src >= 0 && cmp < 0: ige -> true, ilt -> false */
bool replacement = op == nir_op_ige;
b->cursor = nir_after_instr(&alu->instr);
nir_def_replace(&alu->def, nir_imm_bool(b, replacement));
return true;
}
/* src >= 0 && cmp >= 0: same as unsigned cmp. */
op = op == nir_op_ilt ? nir_op_ult : nir_op_uge;
}
if (src_uub >= nir_scalar_as_uint(cmp))
return false;
/* Replace ult with true, uge with false. */
bool replacement = op == nir_op_ult;
b->cursor = nir_after_instr(&alu->instr);
nir_def_replace(&alu->def, nir_imm_bool(b, replacement));
return true;
}
/* umin src, const: if uub(src) <= const -> src
* umax src, const: if uub(src) <= const -> const
* imin src, const: if uub(src) >= 0 && const < 0 -> const
* if uub(src) >= 0 && const >= 0 -> umin src, const
* imax src, const: if uub(src) >= 0 && const < 0 -> src
* if uub(src) >= 0 && const >= 0 -> umax src, const
*/
static bool
opt_uub_minmax(nir_builder *b, nir_alu_instr *alu, opt_uub_state *state)
{
assert(alu->op == nir_op_umin || alu->op == nir_op_umax ||
alu->op == nir_op_imin || alu->op == nir_op_imax);
nir_scalar src, const_src;
if (!get_src_and_const(alu, &src, &const_src, NULL))
return false;
uint32_t src_uub = uub(state, src);
nir_op op = alu->op;
if (op == nir_op_imin || op == nir_op_imax) {
/* If src could be negative, there's nothing we can prove. */
if (util_sign_extend(src_uub, src.def->bit_size) < 0)
return false;
if (nir_scalar_as_int(const_src) < 0) {
/* src >= 0 && const < 0: imin -> const, imax -> src */
nir_scalar replacement = alu->op == nir_op_imin ? const_src : src;
b->cursor = nir_after_instr(&alu->instr);
nir_def_replace(&alu->def, nir_mov_scalar(b, replacement));
return true;
}
/* src >= 0 && cmp >= 0: same as umin/umax. */
op = op == nir_op_imin ? nir_op_umin : nir_op_umax;
}
if (src_uub > nir_scalar_as_uint(const_src))
return false;
nir_scalar replacement = alu->op == nir_op_umax ? const_src : src;
b->cursor = nir_after_instr(&alu->instr);
nir_def_replace(&alu->def, nir_mov_scalar(b, replacement));
return true;
}
static bool
try_replace_imul(nir_builder *b, nir_alu_instr *alu, nir_scalar *srcs,
uint32_t *src_uubs, unsigned bits_used, nir_op op)
{
uint32_t max = (1 << bits_used) - 1;
if (src_uubs[0] > max || src_uubs[1] > max)
return false;
b->cursor = nir_after_instr(&alu->instr);
nir_def_replace(&alu->def, nir_build_alu2(b, op, nir_mov_scalar(b, srcs[0]),
nir_mov_scalar(b, srcs[1])));
return true;
}
/* imul src0, src1: if uub(srci) < UINT16_MAX -> umul_16x16 src0, src1
* imul src0, src1: if uub(srci) < UINT24_MAX -> umul24 src0, src1
* imul src0, src1: if uub(srci) < UINT23_MAX -> imul24 src0, src1
*/
static bool
opt_uub_imul(nir_builder *b, nir_alu_instr *alu, opt_uub_state *state)
{
assert(alu->op == nir_op_imul);
if (!state->options->opt_imul || alu->def.bit_size != 32)
return false;
const nir_shader_compiler_options *shader_options = b->shader->options;
nir_scalar srcs[2];
get_srcs(alu, srcs);
uint32_t src_uubs[] = {uub(state, srcs[0]), uub(state, srcs[1])};
if (shader_options->has_umul_16x16 &&
try_replace_imul(b, alu, srcs, src_uubs, 16, nir_op_umul_16x16)) {
return true;
}
if ((shader_options->has_umul24 || shader_options->has_mul24_relaxed) &&
try_replace_imul(b, alu, srcs, src_uubs, 24,
b->shader->options->has_mul24_relaxed
? nir_op_umul24_relaxed
: nir_op_umul24)) {
return true;
}
/* imul24 sign-extends its 24-bit sources which may give the wrong result
* if the unsigned upper bound fits in 24-bits. Check if it fits is 23-bits
* to avoid sign-extension.
*/
if ((shader_options->has_imul24 || shader_options->has_mul24_relaxed) &&
try_replace_imul(b, alu, srcs, src_uubs, 23,
b->shader->options->has_mul24_relaxed
? nir_op_imul24_relaxed
: nir_op_imul24)) {
return true;
}
return false;
}
static bool
src_is_const(nir_src *src, void *data)
{
return nir_src_is_const(*src);
}
static bool
opt_uub(nir_builder *b, nir_alu_instr *alu, void *data)
{
/* nir_unsigned_upper_bound calculates 32-bit upper bounds so ignore 64-bit
* instructions. Also ignore non-scalar instructions to simplify the code.
*/
if (alu->def.bit_size > 32 || alu->def.num_components > 1)
return false;
/* If all sources are constant, let constant folding handle this. */
if (nir_foreach_src(&alu->instr, src_is_const, NULL))
return false;
opt_uub_state *state = data;
/* If the upper bound is zero, zero is the only possible value. */
if (uub(state, nir_get_scalar(&alu->def, 0)) == 0) {
b->cursor = nir_after_def(&alu->def);
nir_def_replace(&alu->def, nir_imm_zero(b, 1, alu->def.bit_size));
return true;
}
switch (alu->op) {
case nir_op_iand:
return opt_uub_iand(b, alu, state);
case nir_op_ult:
case nir_op_uge:
case nir_op_ilt:
case nir_op_ige:
return opt_uub_cmp(b, alu, state);
case nir_op_umin:
case nir_op_umax:
case nir_op_imin:
case nir_op_imax:
return opt_uub_minmax(b, alu, state);
case nir_op_imul:
return opt_uub_imul(b, alu, state);
default:
return false;
}
}
/* Performs a number of optimizations that make use of
* nir_unsigned_upper_bound to simplify/remove instructions.
*/
bool
nir_opt_uub(nir_shader *shader, const nir_opt_uub_options *options)
{
opt_uub_state state = {
.options = options,
.shader = shader,
.range_ht = _mesa_pointer_hash_table_create(NULL),
};
bool progress =
nir_shader_alu_pass(shader, opt_uub, nir_metadata_control_flow, &state);
_mesa_hash_table_destroy(state.range_ht, NULL);
return progress;
}