nir/algebraic: Clean up up-cast of down-cast when we can

There are a bunch of cases where we can pretty quickly determine that
the high bits don't matter.  In these cases, delete the casts.

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/8872>
This commit is contained in:
Jason Ekstrand 2021-02-04 11:55:43 -06:00 committed by Marge Bot
parent 96303a59ea
commit f9b3be09e1
2 changed files with 51 additions and 0 deletions

View file

@ -1817,6 +1817,45 @@ optimizations += [
(('i2i8', ('iand', 'a@64', 0xff)), ('u2u8', a)),
]
# Some operations such as iadd have the property that the bottom N bits of the
# output only depends on the bottom N bits of each of the inputs so we can
# remove casts
for N in [16, 32]:
for M in [8, 16]:
if M >= N:
continue
aN = 'a@' + str(N)
u2uM = 'u2u{0}'.format(M)
i2iM = 'i2i{0}'.format(M)
for x in ['u', 'i']:
x2xN = '{0}2{0}{1}'.format(x, N)
extract_xM = 'extract_{0}{1}'.format(x, M)
x2xN_M_bits = '{0}(only_lower_{1}_bits_used)'.format(x2xN, M)
extract_xM_M_bits = \
'{0}(only_lower_{1}_bits_used)'.format(extract_xM, M)
optimizations += [
((x2xN_M_bits, (u2uM, aN)), a),
((extract_xM_M_bits, aN, 0), a),
]
bcsel_M_bits = 'bcsel(only_lower_{0}_bits_used)'.format(M)
optimizations += [
((bcsel_M_bits, c, (x2xN, (u2uM, aN)), b), ('bcsel', c, a, b)),
((bcsel_M_bits, c, (x2xN, (i2iM, aN)), b), ('bcsel', c, a, b)),
((bcsel_M_bits, c, (extract_xM, aN, 0), b), ('bcsel', c, a, b)),
]
for op in ['iadd', 'imul', 'iand', 'ior', 'ixor']:
op_M_bits = '{0}(only_lower_{1}_bits_used)'.format(op, M)
optimizations += [
((op_M_bits, (x2xN, (u2uM, aN)), b), (op, a, b)),
((op_M_bits, (x2xN, (i2iM, aN)), b), (op, a, b)),
((op_M_bits, (extract_xM, aN, 0), b), (op, a, b)),
]
def fexp2i(exp, bits):
# Generate an expression which constructs value 2.0^exp or 0.0.
#

View file

@ -337,6 +337,18 @@ is_only_used_as_float(nir_alu_instr *instr)
return true;
}
static inline bool
only_lower_8_bits_used(nir_alu_instr *instr)
{
return (nir_ssa_def_bits_used(&instr->dest.dest.ssa) & ~0xffull) == 0;
}
static inline bool
only_lower_16_bits_used(nir_alu_instr *instr)
{
return (nir_ssa_def_bits_used(&instr->dest.dest.ssa) & ~0xffffull) == 0;
}
/**
* Returns true if a NIR ALU src represents a constant integer
* of either 32 or 64 bits, and the higher word (bit-size / 2)