diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py
index 42462d5befa..abefbb54756 100644
--- a/src/compiler/nir/nir_opt_algebraic.py
+++ b/src/compiler/nir/nir_opt_algebraic.py
@@ -1037,6 +1037,84 @@ for N, M in itertools.product(type_sizes('uint'), type_sizes('uint')):
         # The N == M case is handled by other optimizations
         pass
 
+# Simplify comparisons of up-casted values.
+#
+# Within one type family an up-cast is lossless, so a correctly signed
+# comparison of up-casted values can be done on narrower values instead.
+for t in ['int', 'uint', 'float']:
+    for N, M in itertools.product(type_sizes(t), repeat=2):
+        # Only N-bit to M-bit up-casts (N < M) with N != 1 are handled.
+        if not (N != 1 and N < M):
+            continue
+
+        # Opcode names: the cast to M bits, the cast to N bits and the
+        # comparisons for this type.
+        x2xM = t[0] + '2' + t[0] + str(M)
+        x2xN = t[0] + '2' + t[0] + str(N)
+        xeq, xne = ('feq', 'fne') if t == 'float' else ('ieq', 'ine')
+        xge = t[0] + 'ge'
+        xlt = t[0] + 'lt'
+
+        # Variables constrained to N bits and the up-cast of the first one.
+        aN = 'a@%d' % N
+        bN = 'b@%d' % N
+        wide_a = (x2xM, aN)
+
+        # Both operands are up-casts to M bits, the second one from P <= N
+        # bits: compare at N bits and drop one or both of the casts.  When
+        # P == N the cast of b is a no-op; we have optimizations to drop
+        # the no-op casts which this may generate.
+        for P in type_sizes(t):
+            if P == 1 or N < P:
+                continue
+
+            bP = 'b@%d' % P
+            wide_b = (x2xM, bP)
+            for cmp_op in [xeq, xne, xge, xlt]:
+                optimizations.append(
+                    ((cmp_op, wide_a, wide_b), (cmp_op, a, (x2xN, b))))
+            # The ordered comparisons are also matched with the operands
+            # in the opposite order.
+            for cmp_op in [xge, xlt]:
+                optimizations.append(
+                    ((cmp_op, wide_b, wide_a), (cmp_op, (x2xN, b), a)))
+
+        # The next bit doesn't work on floats because the range checks
+        # would get way too complicated.
+        if t == 'float':
+            continue
+
+        # Smallest and largest values representable in N bits: the signed
+        # range for 'int', the unsigned range for 'uint'.
+        value_bits = N - 1 if t == 'int' else N
+        xN_min = -(1 << value_bits) if t == 'int' else 0
+        xN_max = (1 << value_bits) - 1
+
+        # If we're up-casting and comparing to a constant, unfold the
+        # comparison into a comparison with the constant shrunk down to N
+        # bits and a check that the constant fits in N bits.
+        small_a = (x2xN, a)
+        small_b = (x2xN, b)
+        wide_bN = (x2xM, bN)
+        optimizations.extend([
+            ((xeq, wide_a, '#b'),
+             ('iand', (xeq, a, small_b), (xeq, (x2xM, small_b), b))),
+            ((xne, wide_a, '#b'),
+             ('ior', (xne, a, small_b), (xne, (x2xM, small_b), b))),
+            ((xlt, wide_a, '#b'),
+             ('iand', (xlt, xN_min, b),
+              ('ior', (xlt, xN_max, b), (xlt, a, small_b)))),
+            ((xlt, '#a', wide_bN),
+             ('iand', (xlt, a, xN_max),
+              ('ior', (xlt, a, xN_min), (xlt, small_a, b)))),
+            ((xge, wide_a, '#b'),
+             ('iand', (xge, xN_max, b),
+              ('ior', (xge, xN_min, b), (xge, a, small_b)))),
+            ((xge, '#a', wide_bN),
+             ('iand', (xge, a, xN_min),
+              ('ior', (xge, a, xN_max), (xge, small_a, b)))),
+        ])
+
 def fexp2i(exp, bits):
     # We assume that exp is already in the right range.
     if bits == 16: