nir/opt_algebraic: look through fabs/fneg when matching fmulz/ffmaz

Prevents regressions when removing input modifiers from a == 0.0.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29467>
This commit is contained in:
Georg Lehmann 2024-05-30 22:11:11 +02:00 committed by Marge Bot
parent 080e03d021
commit 8e6bf596cb

View file

@@ -102,6 +102,49 @@ def lowered_sincos(c):
def intBitsToFloat(i):
    """Reinterpret the bits of a 32-bit unsigned integer as a float.

    The integer is packed as 4 bytes ('!I') and those same bytes are then
    unpacked as a single-precision float ('!f'); no numeric conversion of
    the value takes place.
    """
    raw_bytes = struct.pack('!I', i)
    (as_float,) = struct.unpack('!f', raw_bytes)
    return as_float
# Takes a pattern as input and returns a list of patterns where each
# pattern has a different permutation of fneg/fabs(value) as the replacement
# for the key operands in replacements.
#
# pattern:      nested tuple describing one rule; any operand equal to a key
#               of replacements is treated as a placeholder.
# replacements: dict mapping placeholder -> base value. Every placeholder is
#               substituted by one of
#                 ('fneg', ('fabs', base)), ('fabs', base), ('fneg', base), base
# commutative:  when true, only unordered selections of modifiers across the
#               placeholders are generated (combinations_with_replacement);
#               otherwise every ordered selection is generated (product).
#               NOTE(review): presumably the unordered form relies on the
#               pattern being symmetric in its placeholders - confirm when
#               adding new callers.
#
# The input pattern is never modified and all placeholders are substituted
# simultaneously, so a base value that happens to equal another placeholder
# is not rewritten a second time.
def add_fabs_fneg(pattern, replacements, commutative = True):
    # Modifier wrappers, indexed by the per-placeholder choice 0..3.
    modifiers = (
        lambda value: ('fneg', ('fabs', value)),
        lambda value: ('fabs', value),
        lambda value: ('fneg', value),
        lambda value: value,
    )

    def rebuild(node, substitutions):
        # Containers are always copied into fresh tuples, which both keeps
        # the caller's pattern untouched and normalises lists to tuples.
        if isinstance(node, (tuple, list)):
            return tuple(rebuild(child, substitutions) for child in node)
        # Leaves are compared with ==, one placeholder at a time; the
        # substituted value is returned as-is and never re-scanned.
        for placeholder, substitute in substitutions:
            if node == placeholder:
                return substitute
        return node

    placeholders = list(replacements)
    # Normalise any list-based base expressions to tuples once up front.
    bases = [rebuild(replacements[name], ()) for name in placeholders]

    if commutative:
        choices = itertools.combinations_with_replacement(range(4), len(placeholders))
    else:
        choices = itertools.product(range(4), repeat=len(placeholders))

    result = []
    for choice in choices:
        substitutions = [
            (name, modifiers[pick](base))
            for name, base, pick in zip(placeholders, bases, choice)
        ]
        result.append(rebuild(pattern, substitutions))
    return result
optimizations = [
(('imul', a, '#b(is_pos_power_of_two)'), ('ishl', a, ('find_lsb', b)), '!options->lower_bitops'),
@@ -274,21 +317,21 @@ optimizations = [
# Optimize open-coded fmulz.
# (b==0.0 ? 0.0 : a) * (a==0.0 ? 0.0 : b) -> fmulz(a, b)
(('fmul@32(nsz)', ('bcsel', ignore_exact('feq', b, 0.0), 0.0, a), ('bcsel', ignore_exact('feq', a, 0.0), 0.0, b)),
('fmulz', a, b), has_fmulz),
(('fmul@32(nsz)', a, ('bcsel', ignore_exact('feq', a, 0.0), 0.0, '#b(is_not_const_zero)')),
('fmulz', a, b), has_fmulz),
*add_fabs_fneg((('fmul@32(nsz)', ('bcsel', ignore_exact('feq', b, 0.0), 0.0, 'ma'), ('bcsel', ignore_exact('feq', a, 0.0), 0.0, 'mb')),
('fmulz', 'ma', 'mb'), has_fmulz), {'ma' : a, 'mb' : b}),
*add_fabs_fneg((('fmul@32(nsz)', 'ma', ('bcsel', ignore_exact('feq', a, 0.0), 0.0, '#b(is_not_const_zero)')),
('fmulz', 'ma', b), has_fmulz), {'ma' : a}),
# ffma(b==0.0 ? 0.0 : a, a==0.0 ? 0.0 : b, c) -> ffmaz(a, b, c)
(('ffma@32(nsz)', ('bcsel', ignore_exact('feq', b, 0.0), 0.0, a), ('bcsel', ignore_exact('feq', a, 0.0), 0.0, b), c),
('ffmaz', a, b, c), has_fmulz),
(('ffma@32(nsz)', a, ('bcsel', ignore_exact('feq', a, 0.0), 0.0, '#b(is_not_const_zero)'), c),
('ffmaz', a, b, c), has_fmulz),
*add_fabs_fneg((('ffma@32(nsz)', ('bcsel', ignore_exact('feq', b, 0.0), 0.0, 'ma'), ('bcsel', ignore_exact('feq', a, 0.0), 0.0, 'mb'), c),
('ffmaz', 'ma', 'mb', c), has_fmulz), {'ma' : a, 'mb' : b}),
*add_fabs_fneg((('ffma@32(nsz)', 'ma', ('bcsel', ignore_exact('feq', a, 0.0), 0.0, '#b(is_not_const_zero)'), c),
('ffmaz', 'ma', b, c), has_fmulz), {'ma' : a}),
# b == 0.0 ? 1.0 : fexp2(fmul(a, b)) -> fexp2(fmulz(a, b))
(('bcsel(nsz,nnan,ninf)', ignore_exact('feq', b, 0.0), 1.0, ('fexp2', ('fmul@32', a, b))),
('fexp2', ('fmulz', a, b)),
has_fmulz),
*add_fabs_fneg((('bcsel(nsz,nnan,ninf)', ignore_exact('feq', b, 0.0), 1.0, ('fexp2', ('fmul@32', a, 'mb'))),
('fexp2', ('fmulz', a, 'mb')),
has_fmulz), {'mb': b}),
]
# Shorthand for the expansion of just the dot product part of the [iu]dp4a