diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index d0df55ff954..48d2bdecdd1 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -6020,6 +6020,7 @@ bool nir_opt_algebraic_before_ffma(nir_shader *shader);
 bool nir_opt_algebraic_before_lower_int64(nir_shader *shader);
 bool nir_opt_algebraic_late(nir_shader *shader);
 bool nir_opt_algebraic_distribute_src_mods(nir_shader *shader);
+bool nir_opt_algebraic_integer_promotion(nir_shader *shader);
 bool nir_opt_constant_folding(nir_shader *shader);
 
 /* Try to combine a and b into a. Return true if combination was possible,
diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py
index 865b04916f4..13d18a3cb1d 100644
--- a/src/compiler/nir/nir_opt_algebraic.py
+++ b/src/compiler/nir/nir_opt_algebraic.py
@@ -3896,6 +3896,30 @@ before_lower_int64_optimizations = [
    (('iadd', ('u2u64', a), ('u2u64', a)), ('ishl', ('u2u64', a), 1)),
 ]
 
+# These optimizations try to reverse the integer promotion found in e.g. OpenCL C. They should be
+# run before any bit_size lowering is done.
+integer_promotion_optimizations = []
+for s in [8, 16]:
+   u2u = 'u2u{}'.format(s)
+   aN = 'a@{}'.format(s)
+   bN = 'b@{}'.format(s)
+
+   for op in ['ineg', 'inot']:
+      integer_promotion_optimizations.extend([
+         ((u2u, (op, 'a@32')), (op, (u2u, a))),
+      ])
+
+   for op in ['iadd', 'imul', 'iand', 'ior', 'ixor']:
+      integer_promotion_optimizations.extend([
+         ((u2u, (op, 'a@32', 'b@32')), (op, (u2u, a), (u2u, b))),
+      ])
+
+   # idiv and irem are more restrictive because we can't simply trim the inputs arbitrarily.
+   for op in ['idiv', 'irem']:
+      integer_promotion_optimizations.extend([
+         ((u2u, (op, ('i2i32', aN), ('i2i32', bN))), (op, a, b)),
+      ])
+
 parser = argparse.ArgumentParser()
 parser.add_argument('--out', required=True)
 args = parser.parse_args()
@@ -3910,3 +3934,5 @@ with open(args.out, "w", encoding='utf-8') as f:
                                        late_optimizations).render())
    f.write(nir_algebraic.AlgebraicPass("nir_opt_algebraic_distribute_src_mods",
                                        distribute_src_mods).render())
+   f.write(nir_algebraic.AlgebraicPass("nir_opt_algebraic_integer_promotion",
+                                       integer_promotion_optimizations).render())
diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs
index 1a778738a97..53211dfdd5d 100644
--- a/src/gallium/frontends/rusticl/core/kernel.rs
+++ b/src/gallium/frontends/rusticl/core/kernel.rs
@@ -601,6 +601,7 @@ fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
     nir_pass!(nir, nir_lower_alu);
     progress |= nir_pass!(nir, nir_opt_phi_precision);
     progress |= nir_pass!(nir, nir_opt_algebraic);
+    progress |= nir_pass!(nir, nir_opt_algebraic_integer_promotion);
     progress |= nir_pass!(
         nir,
         nir_opt_if,