nir: Add a pass to reassociate multiplication of mat*mat*vec.

The typical case of mat4*mat4*vec4 is 80 scalar multiplications, but
mat4*(mat4*vec4) is only 32.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35622>
This commit is contained in:
Emma Anholt 2025-06-18 16:05:01 -07:00 committed by Marge Bot
parent 21ea8c205f
commit bc8994cb48
2 changed files with 51 additions and 0 deletions

View file

@ -6072,6 +6072,7 @@ bool nir_opt_algebraic_before_lower_int64(nir_shader *shader);
bool nir_opt_algebraic_late(nir_shader *shader);
bool nir_opt_algebraic_distribute_src_mods(nir_shader *shader);
bool nir_opt_algebraic_integer_promotion(nir_shader *shader);
/* Reassociate multiplication of mat*mat*vec into mat*(mat*vec), which needs
 * fewer scalar multiplications (80 vs. 32 for the mat4/vec4 case).
 */
bool nir_opt_reassociate_matrix_mul(nir_shader *shader);
bool nir_opt_constant_folding(nir_shader *shader);
/* Try to combine a and b into a. Return true if combination was possible,

View file

@ -3949,6 +3949,54 @@ distribute_src_mods = [
(('fabs', ('fsign(is_used_once)', a)), ('fsign', ('fabs', a))),
]
# Reassociate multiplication of mat*mat*vec. The typical case of
# mat4*mat4*vec4 is 80 scalar multiplications, but mat4*(mat4*vec4) is only 32.
#
# Each entry appended to mat_mul_optimizations is a (search, replace) pair of
# nested tuples describing one scalar component of the result.
mat_mul_optimizations = []
# Build every rule twice: once with the float opcodes (fadd/fmul) and once
# with the integer opcodes (iadd/imul).
for t in ['f', 'i']:
# NOTE(review): the '~' prefix and the parenthesised suffixes
# (many-comm-expr, is_used_once) are nir_algebraic pattern annotations whose
# exact semantics are defined outside this hunk -- confirm against
# nir_algebraic.py before changing them.
# add_first is only used as the outermost add of each search pattern.
add_first = '~{}add(many-comm-expr)'.format(t)
# add_used_once is only used as the outermost add of step1..step4.
add_used_once = '~{}add(is_used_once)'.format(t)
add = '~{}add'.format(t)
mul = '~{}mul'.format(t)
# Variable names used below were selected based on these layouts:
#     mat4               mat4              vec4
# [a, b, c, d]     [q,  r,  s,  t ]       [gg]
# [e, f, g, h]     [u,  v,  w,  x ]       [hh]
# [i, j, k, l]  x  [y,  z,  aa, bb]   x   [ii]
# [m, n, o, p]     [cc, dd, ee, ff]       [jj]
#
# Only the first row (a, b, c, d) of the left matrix is referenced by the
# patterns below; e..p appear in the diagram for orientation only.
#
# step1..step4: the four elements of the first row of (mat4 * mat4), i.e.
# row [a, b, c, d] dotted with column 1..4 of the second matrix.
step1 = (add_used_once, (add, (add, (mul, 'a', 'q'), (mul, 'b', 'u')), (mul, 'c', 'y')), (mul, 'd', 'cc'))
step2 = (add_used_once, (add, (add, (mul, 'a', 'r'), (mul, 'b', 'v')), (mul, 'c', 'z')), (mul, 'd', 'dd'))
step3 = (add_used_once, (add, (add, (mul, 'a', 's'), (mul, 'b', 'w')), (mul, 'c', 'aa')), (mul, 'd', 'ee'))
step4 = (add_used_once, (add, (add, (mul, 'a', 't'), (mul, 'b', 'x')), (mul, 'c', 'bb')), (mul, 'd', 'ff'))
# step5..step8: the four components of (second mat4 * vec4), i.e. row 1..4 of
# the second matrix dotted with [gg, hh, ii, jj].
step5 = (add, (add, (add, (mul, 'q', 'gg'), (mul, 'r', 'hh')), (mul, 's', 'ii')), (mul, 't', 'jj'))
step6 = (add, (add, (add, (mul, 'u', 'gg'), (mul, 'v', 'hh')), (mul, 'w', 'ii')), (mul, 'x', 'jj'))
step7 = (add, (add, (add, (mul, 'y', 'gg'), (mul, 'z', 'hh')), (mul, 'aa', 'ii')), (mul, 'bb', 'jj'))
step8 = (add, (add, (add, (mul, 'cc', 'gg'), (mul, 'dd', 'hh')), (mul, 'ee', 'ii')), (mul, 'ff', 'jj'))
# This finds and replaces common (mat4*mat4)*vec4 with something that will get optimised down to mat4*(mat4*vec4).
# Search:  step1*gg + step2*hh + step3*ii + step4*jj
# Replace: step5*a  + step6*b  + step7*c  + step8*d
mat_mul_optimizations += [((add_first, (add, (add, (mul, step1, 'gg'), (mul, step2, 'hh')), (mul, step3, 'ii')), (mul, step4, 'jj')), (add, (add, (add, (mul, step5, 'a'), (mul, step6, 'b')), (mul, step7, 'c')), (mul, step8, 'd')))]
# This helps propagate the above improvement further up the mul chain e.g. mat4*mat4*mat4*vec4 to (mat4*vec4)*mat4*mat4.
# Same rule as above, except that the search pattern has the vector component
# as the first operand of each mul; the replacement is identical.
mat_mul_optimizations += [((add_first, (add, (add, (mul, 'gg', step1), (mul,'hh', step2)), (mul, 'ii', step3)), (mul, 'jj', step4)), (add, (add, (add, (mul, step5, 'a'), (mul, step6, 'b')), (mul, step7, 'c')), (mul, step8, 'd')))]
# Below handles a real world shader that looks like mat4*mat4*vec4(xyz, 1.0), where the multiplication by the 1.0 constant has been optimised away.
# The last term of each sum is therefore the bare matrix element (t, x, bb,
# ff) rather than a mul by jj, and step4 appears unmultiplied in the search.
step5_no_w_mul = (add, (add, (add, (mul, 'q', 'gg'), (mul, 'r', 'hh')), (mul, 's', 'ii')), 't')
step6_no_w_mul = (add, (add, (add, (mul, 'u', 'gg'), (mul, 'v', 'hh')), (mul, 'w', 'ii')), 'x')
step7_no_w_mul = (add, (add, (add, (mul, 'y', 'gg'), (mul, 'z', 'hh')), (mul, 'aa', 'ii')), 'bb')
step8_no_w_mul = (add, (add, (add, (mul, 'cc', 'gg'), (mul, 'dd', 'hh')), (mul, 'ee', 'ii')), 'ff')
mat_mul_optimizations += [((add_first, (add, (add, (mul, step1, 'gg'), (mul, step2, 'hh')), (mul, step3, 'ii')), step4), (add, (add, (add, (mul, step5_no_w_mul, 'a'), (mul, step6_no_w_mul, 'b')), (mul, step7_no_w_mul, 'c')), (mul, step8_no_w_mul, 'd')))]
# Below handles a real world shader that looks like mat4*mat4*vec4(xy, 0.0, 1.0), where the multiplications by the 0.0 and 1.0 constants have been optimised away.
# The ii (z) term is dropped entirely and the jj (w) term is unmultiplied.
# Note the operand order in the search pattern: (step1*gg + step4) + step2*hh.
step5_zero_z_no_w_mul = (add, (add, (mul, 'q', 'gg'), (mul, 'r', 'hh')), 't')
step6_zero_z_no_w_mul = (add, (add, (mul, 'u', 'gg'), (mul, 'v', 'hh')), 'x')
step7_zero_z_no_w_mul = (add, (add, (mul, 'y', 'gg'), (mul, 'z', 'hh')), 'bb')
step8_zero_z_no_w_mul = (add, (add, (mul, 'cc', 'gg'), (mul, 'dd', 'hh')), 'ff')
mat_mul_optimizations += [((add_first, (add, (mul, step1, 'gg'), step4), (mul, step2, 'hh')), (add, (add, (add, (mul, step5_zero_z_no_w_mul, 'a'), (mul, step6_zero_z_no_w_mul, 'b')), (mul, step7_zero_z_no_w_mul, 'c')), (mul, step8_zero_z_no_w_mul, 'd')))]
before_lower_int64_optimizations = [
# The i2i64(a) implies that 'a' has at most 32-bits of data.
(('ishl', ('i2i64', a), b),
@ -4026,3 +4074,5 @@ with open(args.out, "w", encoding='utf-8') as f:
distribute_src_mods).render())
f.write(nir_algebraic.AlgebraicPass("nir_opt_algebraic_integer_promotion",
integer_promotion_optimizations).render())
f.write(nir_algebraic.AlgebraicPass("nir_opt_reassociate_matrix_mul",
mat_mul_optimizations).render())