From f9b3be09e15c35149ac5395b0419186580381019 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Thu, 4 Feb 2021 11:55:43 -0600 Subject: [PATCH] nir/algebraic: Clean up up-cast of down-cast when we can There are a bunch of cases where we can pretty quickly determine that the high bits don't matter. In these cases, delete the casts. Reviewed-by: Ian Romanick Part-of: --- src/compiler/nir/nir_opt_algebraic.py | 39 +++++++++++++++++++++++++++ src/compiler/nir/nir_search_helpers.h | 12 +++++++++ 2 files changed, 51 insertions(+) diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index c719cd096f8..51ce418dea2 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -1817,6 +1817,45 @@ optimizations += [ (('i2i8', ('iand', 'a@64', 0xff)), ('u2u8', a)), ] +# Some operations such as iadd have the property that the bottom N bits of the +# output only depends on the bottom N bits of each of the inputs so we can +# remove casts +for N in [16, 32]: + for M in [8, 16]: + if M >= N: + continue + + aN = 'a@' + str(N) + u2uM = 'u2u{0}'.format(M) + i2iM = 'i2i{0}'.format(M) + + for x in ['u', 'i']: + x2xN = '{0}2{0}{1}'.format(x, N) + extract_xM = 'extract_{0}{1}'.format(x, M) + + x2xN_M_bits = '{0}(only_lower_{1}_bits_used)'.format(x2xN, M) + extract_xM_M_bits = \ + '{0}(only_lower_{1}_bits_used)'.format(extract_xM, M) + optimizations += [ + ((x2xN_M_bits, (u2uM, aN)), a), + ((extract_xM_M_bits, aN, 0), a), + ] + + bcsel_M_bits = 'bcsel(only_lower_{0}_bits_used)'.format(M) + optimizations += [ + ((bcsel_M_bits, c, (x2xN, (u2uM, aN)), b), ('bcsel', c, a, b)), + ((bcsel_M_bits, c, (x2xN, (i2iM, aN)), b), ('bcsel', c, a, b)), + ((bcsel_M_bits, c, (extract_xM, aN, 0), b), ('bcsel', c, a, b)), + ] + + for op in ['iadd', 'imul', 'iand', 'ior', 'ixor']: + op_M_bits = '{0}(only_lower_{1}_bits_used)'.format(op, M) + optimizations += [ + ((op_M_bits, (x2xN, (u2uM, aN)), b), (op, a, b)), + ((op_M_bits, (x2xN, (i2iM, aN)), b), (op, a, b)), + ((op_M_bits, (extract_xM, aN, 0), b), (op, a, b)), + ] + def fexp2i(exp, bits): # Generate an expression which constructs value 2.0^exp or 0.0. # diff --git a/src/compiler/nir/nir_search_helpers.h b/src/compiler/nir/nir_search_helpers.h index 323f31a9b6a..5d5fe90ea57 100644 --- a/src/compiler/nir/nir_search_helpers.h +++ b/src/compiler/nir/nir_search_helpers.h @@ -337,6 +337,18 @@ is_only_used_as_float(nir_alu_instr *instr) return true; } +static inline bool +only_lower_8_bits_used(nir_alu_instr *instr) +{ + return (nir_ssa_def_bits_used(&instr->dest.dest.ssa) & ~0xffull) == 0; +} + +static inline bool +only_lower_16_bits_used(nir_alu_instr *instr) +{ + return (nir_ssa_def_bits_used(&instr->dest.dest.ssa) & ~0xffffull) == 0; +} + /** * Returns true if a NIR ALU src represents a constant integer * of either 32 or 64 bits, and the higher word (bit-size / 2)