nir: add nir_opt_algebraic_integer_promotion

This handles basic operations where clang promotes integers to 32 bits
according to the C99 spec in OpenCL C source code.

This is its own opt_algerbraic pass, because we don't wanna fight with
nir_lower_bit_size.

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Christian Gmeiner <cgmeiner@igalia.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34641>
This commit is contained in:
Karol Herbst 2025-04-21 18:05:58 +02:00 committed by Marge Bot
parent 2582cf9971
commit f0fa2209a8
3 changed files with 28 additions and 0 deletions

View file

@ -6020,6 +6020,7 @@ bool nir_opt_algebraic_before_ffma(nir_shader *shader);
bool nir_opt_algebraic_before_lower_int64(nir_shader *shader);
bool nir_opt_algebraic_late(nir_shader *shader);
bool nir_opt_algebraic_distribute_src_mods(nir_shader *shader);
bool nir_opt_algebraic_integer_promotion(nir_shader *shader);
bool nir_opt_constant_folding(nir_shader *shader);
/* Try to combine a and b into a. Return true if combination was possible,

View file

@ -3896,6 +3896,30 @@ before_lower_int64_optimizations = [
(('iadd', ('u2u64', a), ('u2u64', a)), ('ishl', ('u2u64', a), 1)),
]
# Those optimizations try to reverse integer promotion found in e.g. OpenCL C. Those should be ran
# before any bit_size lowering is done.
integer_promotion_optimizations = []
for s in [8, 16]:
u2u = 'u2u{}'.format(s)
aN = 'a@{}'.format(s)
bN = 'b@{}'.format(s)
for op in ['ineg', 'inot']:
integer_promotion_optimizations.extend([
((u2u, (op, 'a@32')), (op, (u2u, a))),
])
for op in ['iadd', 'imul', 'iand', 'ior', 'ixor']:
integer_promotion_optimizations.extend([
((u2u, (op, 'a@32', 'b@32')), (op, (u2u, a), (u2u, b))),
])
# idiv and irem are more restrictive because we can't simply trim the inputs arbitrarily.
for op in ['idiv', 'irem']:
integer_promotion_optimizations.extend([
((u2u, (op, ('i2i32', aN), ('i2i32', bN))), (op, a, b)),
])
parser = argparse.ArgumentParser()
parser.add_argument('--out', required=True)
args = parser.parse_args()
@ -3910,3 +3934,5 @@ with open(args.out, "w", encoding='utf-8') as f:
late_optimizations).render())
f.write(nir_algebraic.AlgebraicPass("nir_opt_algebraic_distribute_src_mods",
distribute_src_mods).render())
f.write(nir_algebraic.AlgebraicPass("nir_opt_algebraic_integer_promotion",
integer_promotion_optimizations).render())

View file

@ -601,6 +601,7 @@ fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
nir_pass!(nir, nir_lower_alu);
progress |= nir_pass!(nir, nir_opt_phi_precision);
progress |= nir_pass!(nir, nir_opt_algebraic);
progress |= nir_pass!(nir, nir_opt_algebraic_integer_promotion);
progress |= nir_pass!(
nir,
nir_opt_if,