diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 081eb78f6d7..04476ecc932 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -6072,6 +6072,7 @@ bool nir_opt_algebraic_before_lower_int64(nir_shader *shader); bool nir_opt_algebraic_late(nir_shader *shader); bool nir_opt_algebraic_distribute_src_mods(nir_shader *shader); bool nir_opt_algebraic_integer_promotion(nir_shader *shader); +bool nir_opt_reassociate_matrix_mul(nir_shader *shader); bool nir_opt_constant_folding(nir_shader *shader); /* Try to combine a and b into a. Return true if combination was possible, diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index e9744c7c6aa..4d6884861fd 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -3949,6 +3949,54 @@ distribute_src_mods = [ (('fabs', ('fsign(is_used_once)', a)), ('fsign', ('fabs', a))), ] +# Reassociate multiplication of mat*mat*vec. The typical case of +# mat4*mat4*vec4 is 80 scalar multiplications, but mat4*(mat4*vec4) is only 32. 
+mat_mul_optimizations = [] +for t in ['f', 'i']: + add_first = '~{}add(many-comm-expr)'.format(t) + add_used_once = '~{}add(is_used_once)'.format(t) + add = '~{}add'.format(t) + mul = '~{}mul'.format(t) + + # Variable names used below were selected based on these layouts: + # mat4 mat4 vec4 + # [a, b, c, d] [q, r, s, t] [gg] + # [e, f, g, h] [u, v, w, x] [hh] + # [i, j, k, l] x [y, z, aa, bb] x [ii] + # [m, n, o, p] [cc, dd, ee, ff] [jj] + # + # step1..step4 multiply row [a, b, c, d] of the first matrix by each column of the second matrix (one row of mat4*mat4); step5..step8 multiply each row of the second matrix by the vec4 (the components of mat4*vec4). + + step1 = (add_used_once, (add, (add, (mul, 'a', 'q'), (mul, 'b', 'u')), (mul, 'c', 'y')), (mul, 'd', 'cc')) + step2 = (add_used_once, (add, (add, (mul, 'a', 'r'), (mul, 'b', 'v')), (mul, 'c', 'z')), (mul, 'd', 'dd')) + step3 = (add_used_once, (add, (add, (mul, 'a', 's'), (mul, 'b', 'w')), (mul, 'c', 'aa')), (mul, 'd', 'ee')) + step4 = (add_used_once, (add, (add, (mul, 'a', 't'), (mul, 'b', 'x')), (mul, 'c', 'bb')), (mul, 'd', 'ff')) + + step5 = (add, (add, (add, (mul, 'q', 'gg'), (mul, 'r', 'hh')), (mul, 's', 'ii')), (mul, 't', 'jj')) + step6 = (add, (add, (add, (mul, 'u', 'gg'), (mul, 'v', 'hh')), (mul, 'w', 'ii')), (mul, 'x', 'jj')) + step7 = (add, (add, (add, (mul, 'y', 'gg'), (mul, 'z', 'hh')), (mul, 'aa', 'ii')), (mul, 'bb', 'jj')) + step8 = (add, (add, (add, (mul, 'cc', 'gg'), (mul, 'dd', 'hh')), (mul, 'ee', 'ii')), (mul, 'ff', 'jj')) + + # This finds and replaces common (mat4*mat4)*vec4 with something that will get optimised down to mat4*(mat4*vec4) + mat_mul_optimizations += [((add_first, (add, (add, (mul, step1, 'gg'), (mul, step2, 'hh')), (mul, step3, 'ii')), (mul, step4, 'jj')), (add, (add, (add, (mul, step5, 'a'), (mul, step6, 'b')), (mul, step7, 'c')), (mul, step8, 'd')))] + + # This helps propagate the above improvement further up the mul chain e.g. 
mat4*mat4*mat4*vec4 to (mat4*vec4)*mat4*mat4 + mat_mul_optimizations += [((add_first, (add, (add, (mul, 'gg', step1), (mul,'hh', step2)), (mul, 'ii', step3)), (mul, 'jj', step4)), (add, (add, (add, (mul, step5, 'a'), (mul, step6, 'b')), (mul, step7, 'c')), (mul, step8, 'd')))] + + # Below handles a real world shader that looks like this mat4*mat4*vec4(xyz, 1.0) where the multiplication of the 1.0 constant has been optimised away + step5_no_w_mul = (add, (add, (add, (mul, 'q', 'gg'), (mul, 'r', 'hh')), (mul, 's', 'ii')), 't') + step6_no_w_mul = (add, (add, (add, (mul, 'u', 'gg'), (mul, 'v', 'hh')), (mul, 'w', 'ii')), 'x') + step7_no_w_mul = (add, (add, (add, (mul, 'y', 'gg'), (mul, 'z', 'hh')), (mul, 'aa', 'ii')), 'bb') + step8_no_w_mul = (add, (add, (add, (mul, 'cc', 'gg'), (mul, 'dd', 'hh')), (mul, 'ee', 'ii')), 'ff') + + mat_mul_optimizations += [((add_first, (add, (add, (mul, step1, 'gg'), (mul, step2, 'hh')), (mul, step3, 'ii')), step4), (add, (add, (add, (mul, step5_no_w_mul, 'a'), (mul, step6_no_w_mul, 'b')), (mul, step7_no_w_mul, 'c')), (mul, step8_no_w_mul, 'd')))] + + # Below handles a real world shader that looks like this mat4*mat4*vec4(xy, 0.0, 1.0) where the multiplication of the 0.0 and 1.0 constants has been optimised away + step5_zero_z_no_w_mul = (add, (add, (mul, 'q', 'gg'), (mul, 'r', 'hh')), 't') + step6_zero_z_no_w_mul = (add, (add, (mul, 'u', 'gg'), (mul, 'v', 'hh')), 'x') + step7_zero_z_no_w_mul = (add, (add, (mul, 'y', 'gg'), (mul, 'z', 'hh')), 'bb') + step8_zero_z_no_w_mul = (add, (add, (mul, 'cc', 'gg'), (mul, 'dd', 'hh')), 'ff') + + mat_mul_optimizations += [((add_first, (add, (mul, step1, 'gg'), step4), (mul, step2, 'hh')), (add, (add, (add, (mul, step5_zero_z_no_w_mul, 'a'), (mul, step6_zero_z_no_w_mul, 'b')), (mul, step7_zero_z_no_w_mul, 'c')), (mul, step8_zero_z_no_w_mul, 'd')))] + before_lower_int64_optimizations = [ # The i2i64(a) implies that 'a' has at most 32-bits of data. 
(('ishl', ('i2i64', a), b), @@ -4026,3 +4074,5 @@ with open(args.out, "w", encoding='utf-8') as f: distribute_src_mods).render()) f.write(nir_algebraic.AlgebraicPass("nir_opt_algebraic_integer_promotion", integer_promotion_optimizations).render()) + f.write(nir_algebraic.AlgebraicPass("nir_opt_reassociate_matrix_mul", + mat_mul_optimizations).render())