diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c
index 5ccac8810de..683dbda0a13 100644
--- a/src/asahi/compiler/agx_compile.c
+++ b/src/asahi/compiler/agx_compile.c
@@ -2825,7 +2825,6 @@ agx_optimize_loop_nir(nir_shader *nir)
       NIR_PASS(progress, nir, nir_opt_algebraic);
       NIR_PASS(progress, nir, nir_opt_constant_folding);
       NIR_PASS(progress, nir, nir_opt_undef);
-      NIR_PASS(progress, nir, nir_opt_shrink_vectors, true);
       NIR_PASS(progress, nir, nir_opt_loop_unroll);
    } while (progress);
 }
@@ -3006,6 +3005,14 @@ agx_optimize_nir(nir_shader *nir, bool soft_fault, unsigned *preamble_size)
               .callback = agx_mem_vectorize_cb,
            });
    NIR_PASS(_, nir, nir_lower_pack);
+   NIR_PASS(_, nir, nir_opt_algebraic);
+
+   /* Lower addressing modes. The sooner we do this, the sooner we get rid of
+    * amul/aadd instructions and can let nir_opt_algebraic do its job. But we
+    * want to vectorize first since nir_opt_load_store_vectorize doesn't know
+    * how to handle our loads.
+    */
+   NIR_PASS(_, nir, agx_nir_lower_address);
    NIR_PASS_V(nir, nir_divergence_analysis);
 
    bool progress = false;
@@ -3033,7 +3040,6 @@ agx_optimize_nir(nir_shader *nir, bool soft_fault, unsigned *preamble_size)
    } while (progress);
 
    progress = false;
-   NIR_PASS(progress, nir, agx_nir_lower_address);
 
    /* If address lowering made progress, clean up before forming preambles.
    * Otherwise the optimized preambles might just be constants! Do it before
diff --git a/src/asahi/compiler/agx_compile.h b/src/asahi/compiler/agx_compile.h
index 628a0b25cd6..aa60fa1faf7 100644
--- a/src/asahi/compiler/agx_compile.h
+++ b/src/asahi/compiler/agx_compile.h
@@ -321,10 +321,12 @@ static const nir_shader_compiler_options agx_nir_options = {
    .lower_hadd = true,
    .vectorize_io = true,
    .use_interpolated_input_intrinsics = true,
+   .has_amul = true,
    .has_isub = true,
    .support_16bit_alu = true,
    .max_unroll_iterations = 32,
    .lower_uniforms_to_ubo = true,
+   .late_lower_int64 = true,
    .lower_int64_options =
       (nir_lower_int64_options) ~(nir_lower_iadd64 | nir_lower_imul_2x32_64),
    .lower_doubles_options = (nir_lower_doubles_options)(~0),
diff --git a/src/asahi/compiler/agx_compiler.h b/src/asahi/compiler/agx_compiler.h
index d252622e03f..6c6fd0139b9 100644
--- a/src/asahi/compiler/agx_compiler.h
+++ b/src/asahi/compiler/agx_compiler.h
@@ -1081,7 +1081,6 @@ void agx_liveness_ins_update(BITSET_WORD *live, agx_instr *I);
 
 bool agx_nir_opt_preamble(nir_shader *s, unsigned *preamble_size);
 bool agx_nir_lower_load_mask(nir_shader *shader);
-bool agx_nir_lower_address(nir_shader *shader);
 bool agx_nir_lower_ubo(nir_shader *shader);
 bool agx_nir_lower_shared_bitsize(nir_shader *shader);
 bool agx_nir_lower_frag_sidefx(nir_shader *s);
diff --git a/src/asahi/compiler/agx_nir.h b/src/asahi/compiler/agx_nir.h
index d63a3bf6c86..d7866ba39d0 100644
--- a/src/asahi/compiler/agx_nir.h
+++ b/src/asahi/compiler/agx_nir.h
@@ -9,7 +9,11 @@
 
 struct nir_shader;
 
+bool agx_nir_lower_address(struct nir_shader *shader);
 bool agx_nir_lower_algebraic_late(struct nir_shader *shader);
+bool agx_nir_cleanup_amul(struct nir_shader *shader);
+bool agx_nir_fuse_lea(struct nir_shader *shader);
+bool agx_nir_lower_lea(struct nir_shader *shader);
 bool agx_nir_fuse_selects(struct nir_shader *shader);
 bool agx_nir_fuse_algebraic_late(struct nir_shader *shader);
 bool agx_nir_fence_images(struct nir_shader *shader);
diff --git a/src/asahi/compiler/agx_nir_algebraic.py b/src/asahi/compiler/agx_nir_algebraic.py
index 6711443b09c..9e2c5aa7e5f 100644
--- a/src/asahi/compiler/agx_nir_algebraic.py
+++ b/src/asahi/compiler/agx_nir_algebraic.py
@@ -195,6 +195,98 @@ ixor_bcsel = [
     ('bcsel', a, ('ixor', b, d), ('ixor', c, d))),
 ]
 
+# The main NIR optimizer works on imul, not iadd. We need just enough patterns
+# for amul to let us fuse lea.
+cleanup_amul = [
+    # Neither operation overflows so we can keep the amul.
+    (('amul', ('amul', a, '#b'), '#c'), ('amul', a, ('imul', b, c))),
+]
+
+fuse_lea = []
+
+# Handle 64-bit address arithmetic (OpenCL)
+for s in range(1, 5):
+    pot = 1 << s
+
+    fuse_lea += [
+        # A + (#b + c) 2^s = (A + c 2^s) + #b 2^s
+        (('iadd', 'a@64', ('amul', pot, ('iadd', '#b(is_upper_half_zero)', ('u2u64', 'c@32')))),
+         ('ulea_agx', ('ulea_agx', a, c, s), ('u2u32', b), s)),
+
+        # A + (B + c) 2^s = (A + B 2^s) + c 2^s
+        (('iadd', 'a@64', ('amul', ('iadd', 'b@64', ('i2i64', 'c@32')), pot)),
+         ('ilea_agx', ('iadd', a, ('ishl', b, s)), c, s)),
+
+        # A + 2^s (B + (C + d)) = (A + (B + C)2^s) + d 2^s
+        (('iadd', 'a@64', ('amul', ('iadd', 'b@64',
+                                    ('iadd', 'c@64', ('u2u64', 'd@32'))), pot)),
+         ('ulea_agx', ('iadd', a, ('ishl', ('iadd', b, c), s)), d, s)),
+    ]
+
+    for sgn in ["u", "i"]:
+        upconv = f'{sgn}2{sgn}64'
+        lea = f'{sgn}lea_agx'
+
+        fuse_lea += [
+            # Basic pattern match
+            (('iadd', 'a@64', ('amul', (upconv, 'b@32'), pot)), (lea, a, b, s)),
+            (('iadd', 'a@64', ('ishl', (upconv, 'b@32'), s)), (lea, a, b, s)),
+        ]
+
+# Handle relaxed 32-bit address arithmetic (OpenGL, Vulkan)
+for s_ in range(1, 5):
+    # Iterate backwards
+    s = 5 - s_
+
+    v = 1 << s
+    is_mult = f'(is_unsigned_multiple_of_{v})'
+
+    fuse_lea += [
+        # A + b * s = A + B * s with relaxed multiply
+        (('iadd', 'a@64', ('u2u64', ('amul', 'b@32', v))),
+         ('ulea_agx', a, b, s)),
+
+        # A + (b * c 2^s) = A + (b * c) 2^s with relaxed multiply
+        (('iadd', 'a@64', ('u2u64', ('amul', 'b@32', f'#c{is_mult}'))),
+         ('ulea_agx', a, ('imul', b, ('ushr', c, s)), s)),
+
+        # A + (b 2^s + c d 2^s) = A + (b + cd) 2^s with relaxation.
+        #
+        # amul is bounded by the buffer size by definition, and both the GL & VK
+        # limit UBOs and SSBOs to INT32_MAX bytes. Therefore, amul has no signed
+        # wrap.
+        #
+        # Further, because we are zero-extending the 32-bit result, the 32-bit
+        # sum must be nonnegative -- if it were negative, it would represent an
+        # offset above INT32_MAX which would be invalid given the amul and
+        # max buffer size. Thus with signed math
+        #
+        #    0 <= b 2^s + cd 2^s < INT32_MAX
+        #
+        # ..and hence
+        #
+        #    0 <= b + cd < INT32_MAX
+        #
+        # Those bounds together with distributivity mean that
+        #
+        #    (b 2^s + cd 2^s) mod 2^32 = 2^s ((b + cd) mod 2^32)
+        #
+        # ...which is exactly what we need to factor out the shift.
+        (('iadd', 'a@64', ('u2u64', ('iadd', f'#b{is_mult}',
+                                     ('amul', 'c@32', f'#d{is_mult}')))),
+         ('ulea_agx', a, ('iadd', ('ishr', b, s),
+                          ('amul', 'c@32', ('ishr', d, s))), s)),
+    ]
+
+# After lowering address arithmetic, the various address arithmetic opcodes are
+# no longer useful. Lower them to regular arithmetic to let nir_opt_algebraic
+# take over.
+lower_lea = [
+    (('amul', a, b), ('imul', a, b)),
+    (('ulea_agx', a, b, c), ('iadd', a, ('ishl', ('u2u64', b), c))),
+    (('ilea_agx', a, b, c), ('iadd', a, ('ishl', ('i2i64', b), c))),
+]
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-p', '--import-path', required=True)
@@ -207,6 +299,10 @@ def run():
 
     print('#include "agx_nir.h"')
 
+    print(nir_algebraic.AlgebraicPass("agx_nir_cleanup_amul", cleanup_amul).render())
+    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_lea", fuse_lea).render())
+    print(nir_algebraic.AlgebraicPass("agx_nir_lower_lea", lower_lea).render())
+
     print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
                                       lower_sm5_shift + lower_pack +
                                       lower_selects).render())
diff --git a/src/asahi/compiler/agx_nir_lower_address.c b/src/asahi/compiler/agx_nir_lower_address.c
index b54689c8966..cdd9e306a6b 100644
--- a/src/asahi/compiler/agx_nir_lower_address.c
+++ b/src/asahi/compiler/agx_nir_lower_address.c
@@ -3,239 +3,19 @@
  * SPDX-License-Identifier: MIT
  */
 
+#include <stdint.h>
 #include "compiler/nir/nir_builder.h"
-#include "agx_compiler.h"
+#include "agx_nir.h"
+#include "nir.h"
+#include "nir_intrinsics.h"
+#include "nir_opcodes.h"
 
-/* Results of pattern matching */
 struct match {
    nir_scalar base, offset;
    bool sign_extend;
-
-   /* Signed shift. A negative shift indicates that the offset needs ushr
-    * applied. It's cheaper to fold iadd and materialize an extra ushr, than
-    * to leave the iadd untouched, so this is good.
-    */
-   int8_t shift;
+   uint8_t shift;
 };
 
-/*
- * Try to match a multiplication with an immediate value. This generalizes to
- * both imul and ishl. If successful, returns true and sets the output
- * variables. Otherwise, returns false.
- */
-static bool
-match_imul_imm(nir_scalar scalar, nir_scalar *variable, uint32_t *imm)
-{
-   if (!nir_scalar_is_alu(scalar))
-      return false;
-
-   nir_op op = nir_scalar_alu_op(scalar);
-   if (op != nir_op_imul && op != nir_op_ishl)
-      return false;
-
-   nir_scalar inputs[] = {
-      nir_scalar_chase_alu_src(scalar, 0),
-      nir_scalar_chase_alu_src(scalar, 1),
-   };
-
-   /* For imul check both operands for an immediate, since imul is commutative.
-    * For ishl, only check the operand on the right.
-    */
-   bool commutes = (op == nir_op_imul);
-
-   for (unsigned i = commutes ? 0 : 1; i < ARRAY_SIZE(inputs); ++i) {
-      if (!nir_scalar_is_const(inputs[i]))
-         continue;
-
-      *variable = inputs[1 - i];
-
-      uint32_t value = nir_scalar_as_uint(inputs[i]);
-
-      if (op == nir_op_imul)
-         *imm = value;
-      else
-         *imm = (1 << value);
-
-      return true;
-   }
-
-   return false;
-}
-
-/*
- * Try to rewrite (a << (#b + #c)) + #d as ((a << #b) + #d') << #c,
- * assuming that #d is a multiple of 1 << #c. This takes advantage of
- * the hardware's implicit << #c and avoids a right-shift.
- *
- * Similarly, try to rewrite (a * (#b << #c)) + #d as ((a * #b) + #d') << #c.
- *
- * This pattern occurs with a struct-of-array layout.
- */
-static bool
-match_soa(nir_builder *b, struct match *match, unsigned format_shift)
-{
-   if (!nir_scalar_is_alu(match->offset) ||
-       nir_scalar_alu_op(match->offset) != nir_op_iadd)
-      return false;
-
-   nir_scalar summands[] = {
-      nir_scalar_chase_alu_src(match->offset, 0),
-      nir_scalar_chase_alu_src(match->offset, 1),
-   };
-
-   for (unsigned i = 0; i < ARRAY_SIZE(summands); ++i) {
-      if (!nir_scalar_is_const(summands[i]))
-         continue;
-
-      /* Note: This is treated as signed regardless of the sign of the match.
-       * The final addition into the base can be signed or unsigned, but when
-       * we shift right by the format shift below we need to always sign extend
-       * to ensure that any negative offset remains negative when added into
-       * the index. That is, in:
-       *
-       *    addr = base + (u64)((index + offset) << shift)
-       *
-       * `index` and `offset` are always 32 bits, and a negative `offset` needs
-       * to subtract from the index, so it needs to be sign extended when we
-       * apply the format shift regardless of the fact that the later conversion
-       * to 64 bits does not sign extend.
-       *
-       * TODO: We need to confirm how the hardware handles 32-bit overflow when
-       * applying the format shift, which might need rework here again.
-       */
-      int offset = nir_scalar_as_int(summands[i]);
-      nir_scalar variable;
-      uint32_t multiplier;
-
-      /* The other operand must multiply */
-      if (!match_imul_imm(summands[1 - i], &variable, &multiplier))
-         return false;
-
-      int offset_shifted = offset >> format_shift;
-      uint32_t multiplier_shifted = multiplier >> format_shift;
-
-      /* If the multiplier or the offset are not aligned, we can't rewrite */
-      if (multiplier != (multiplier_shifted << format_shift))
-         return false;
-
-      if (offset != (offset_shifted << format_shift))
-         return false;
-
-      /* Otherwise, rewrite! */
-      nir_def *unmultiplied = nir_vec_scalars(b, &variable, 1);
-
-      nir_def *rewrite = nir_iadd_imm(
-         b, nir_imul_imm(b, unmultiplied, multiplier_shifted), offset_shifted);
-
-      match->offset = nir_get_scalar(rewrite, 0);
-      match->shift = 0;
-      return true;
-   }
-
-   return false;
-}
-
-/* Try to pattern match address calculation */
-static struct match
-match_address(nir_builder *b, nir_scalar base, int8_t format_shift)
-{
-   struct match match = {.base = base};
-
-   /* All address calculations are iadd at the root */
-   if (!nir_scalar_is_alu(base) || nir_scalar_alu_op(base) != nir_op_iadd)
-      return match;
-
-   /* Only 64+32 addition is supported, look for an extension */
-   nir_scalar summands[] = {
-      nir_scalar_chase_alu_src(base, 0),
-      nir_scalar_chase_alu_src(base, 1),
-   };
-
-   for (unsigned i = 0; i < ARRAY_SIZE(summands); ++i) {
-      /* We can add a small constant to the 64-bit base for free */
-      if (nir_scalar_is_const(summands[i]) &&
-          nir_scalar_as_uint(summands[i]) < (1ull << 32)) {
-
-         uint32_t value = nir_scalar_as_uint(summands[i]);
-
-         return (struct match){
-            .base = summands[1 - i],
-            .offset = nir_get_scalar(nir_imm_int(b, value), 0),
-            .shift = -format_shift,
-            .sign_extend = false,
-         };
-      }
-
-      /* Otherwise, we can only add an offset extended from 32-bits */
-      if (!nir_scalar_is_alu(summands[i]))
-         continue;
-
-      nir_op op = nir_scalar_alu_op(summands[i]);
-
-      if (op != nir_op_u2u64 && op != nir_op_i2i64)
-         continue;
-
-      /* We've found a summand, commit to it */
-      match.base = summands[1 - i];
-      match.offset = nir_scalar_chase_alu_src(summands[i], 0);
-      match.sign_extend = (op == nir_op_i2i64);
-
-      /* Undo the implicit shift from using as offset */
-      match.shift = -format_shift;
-      break;
-   }
-
-   /* If we didn't find something to fold in, there's nothing else we can do */
-   if (!match.offset.def)
-      return match;
-
-   /* But if we did, we can try to fold in in a multiply */
-   nir_scalar multiplied;
-   uint32_t multiplier;
-
-   if (match_imul_imm(match.offset, &multiplied, &multiplier)) {
-      int8_t new_shift = match.shift;
-
-      /* Try to fold in either a full power-of-two, or just the power-of-two
-       * part of a non-power-of-two stride.
-       */
-      if (util_is_power_of_two_nonzero(multiplier)) {
-         new_shift += util_logbase2(multiplier);
-         multiplier = 1;
-      } else if (((multiplier >> format_shift) << format_shift) == multiplier) {
-         new_shift += format_shift;
-         multiplier >>= format_shift;
-      } else {
-         return match;
-      }
-
-      nir_def *multiplied_ssa = nir_vec_scalars(b, &multiplied, 1);
-
-      /* Only fold in if we wouldn't overflow the lsl field */
-      if (new_shift <= 2) {
-         match.offset =
-            nir_get_scalar(nir_imul_imm(b, multiplied_ssa, multiplier), 0);
-         match.shift = new_shift;
-      } else if (new_shift > 0) {
-         /* For large shifts, we do need a multiply, but we can
-          * shrink the shift to avoid generating an ishr.
-          */
-         assert(new_shift >= 3);
-
-         nir_def *rewrite =
-            nir_imul_imm(b, multiplied_ssa, multiplier << new_shift);
-
-         match.offset = nir_get_scalar(rewrite, 0);
-         match.shift = 0;
-      }
-   } else {
-      /* Try to match struct-of-arrays pattern, updating match if possible */
-      match_soa(b, &match, format_shift);
-   }
-
-   return match;
-}
-
 static enum pipe_format
 format_for_bitsize(unsigned bitsize)
 {
@@ -271,37 +51,58 @@ pass(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
 
    nir_src *orig_offset = nir_get_io_offset_src(intr);
    nir_scalar base = nir_scalar_resolved(orig_offset->ssa, 0);
-   struct match match = match_address(b, base, format_shift);
+   struct match match = {.base = base};
+   bool shift_must_match =
+      (intr->intrinsic == nir_intrinsic_global_atomic) ||
+      (intr->intrinsic == nir_intrinsic_global_atomic_swap);
+   unsigned max_shift = format_shift + (shift_must_match ? 0 : 2);
+
+   if (nir_scalar_is_alu(base)) {
+      nir_op op = nir_scalar_alu_op(base);
+      if (op == nir_op_ulea_agx || op == nir_op_ilea_agx) {
+         unsigned shift = nir_scalar_as_uint(nir_scalar_chase_alu_src(base, 2));
+         if (shift >= format_shift && shift <= max_shift) {
+            match = (struct match){
+               .base = nir_scalar_chase_alu_src(base, 0),
+               .offset = nir_scalar_chase_alu_src(base, 1),
+               .shift = shift - format_shift,
+               .sign_extend = (op == nir_op_ilea_agx),
+            };
+         }
+      } else if (op == nir_op_iadd) {
+         for (unsigned i = 0; i < 2; ++i) {
+            nir_scalar const_scalar = nir_scalar_chase_alu_src(base, i);
+            if (!nir_scalar_is_const(const_scalar))
+               continue;
+
+            /* Put scalar into form (k*2^n), clamping n at the maximum hardware
+             * shift.
+             */
+            int64_t raw_scalar = nir_scalar_as_uint(const_scalar);
+            uint32_t shift = MIN2(__builtin_ctz(raw_scalar), max_shift);
+            int64_t k = raw_scalar >> shift;
+
+            /* See if the reduced scalar is from a sign extension. */
+            if (k > INT32_MAX || k < INT32_MIN)
+               break;
+
+            /* Match the constant */
+            match = (struct match){
+               .base = nir_scalar_chase_alu_src(base, 1 - i),
+               .offset = nir_get_scalar(nir_imm_int(b, k), 0),
+               .shift = shift - format_shift,
+               .sign_extend = true,
+            };
+
+            break;
+         }
+      }
+   }
 
    nir_def *offset = match.offset.def != NULL
                         ? nir_channel(b, match.offset.def, match.offset.comp)
                         : nir_imm_int(b, 0);
 
-   /* If we were unable to fold in the shift, insert a right-shift now to undo
-    * the implicit left shift of the instruction.
-    */
-   if (match.shift < 0) {
-      if (match.sign_extend)
-         offset = nir_ishr_imm(b, offset, -match.shift);
-      else
-         offset = nir_ushr_imm(b, offset, -match.shift);
-
-      match.shift = 0;
-   }
-
-   /* Hardware offsets must be 32-bits. Upconvert if the source code used
-    * smaller integers.
-    */
-   if (offset->bit_size != 32) {
-      assert(offset->bit_size < 32);
-
-      if (match.sign_extend)
-         offset = nir_i2i32(b, offset);
-      else
-         offset = nir_u2u32(b, offset);
-   }
-
-   assert(match.shift >= 0);
    nir_def *new_base = nir_channel(b, match.base.def, match.base.comp);
    nir_def *repl = NULL;
 
@@ -321,13 +122,11 @@ pass(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
                                  .base = match.shift, .format = format,
                                  .sign_extend = match.sign_extend);
    } else if (intr->intrinsic == nir_intrinsic_global_atomic) {
-      offset = nir_ishl_imm(b, offset, match.shift);
      repl = nir_global_atomic_agx(b, bit_size, new_base, offset,
                                   intr->src[1].ssa,
                                   .atomic_op = nir_intrinsic_atomic_op(intr),
                                   .sign_extend = match.sign_extend);
    } else if (intr->intrinsic == nir_intrinsic_global_atomic_swap) {
-      offset = nir_ishl_imm(b, offset, match.shift);
      repl = nir_global_atomic_swap_agx(
         b, bit_size, new_base, offset, intr->src[1].ssa, intr->src[2].ssa,
         .atomic_op = nir_intrinsic_atomic_op(intr),
@@ -346,8 +145,31 @@ pass(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
 }
 
 bool
-agx_nir_lower_address(nir_shader *shader)
+agx_nir_lower_address(nir_shader *nir)
 {
-   return nir_shader_intrinsics_pass(shader, pass, nir_metadata_control_flow,
-                                     NULL);
+   bool progress = false;
+
+   /* First, clean up as much as possible. This will make fusing more effective.
+    */
+   do {
+      progress = false;
+      NIR_PASS(progress, nir, agx_nir_cleanup_amul);
+      NIR_PASS(progress, nir, nir_opt_constant_folding);
+      NIR_PASS(progress, nir, nir_opt_dce);
+   } while (progress);
+
+   /* Then, fuse as many lea as possible */
+   NIR_PASS(progress, nir, agx_nir_fuse_lea);
+
+   /* Next, lower load/store using the lea's */
+   NIR_PASS(progress, nir, nir_shader_intrinsics_pass, pass,
+            nir_metadata_control_flow, NULL);
+
+   /* Finally, lower any leftover lea instructions back to ALU to let
+    * nir_opt_algebraic simplify them from here.
+    */
+   NIR_PASS(progress, nir, agx_nir_lower_lea);
+   NIR_PASS(progress, nir, nir_opt_dce);
+
+   return progress;
 }
diff --git a/src/gallium/drivers/asahi/agx_state.c b/src/gallium/drivers/asahi/agx_state.c
index e4556bc7f16..5a39967e99c 100644
--- a/src/gallium/drivers/asahi/agx_state.c
+++ b/src/gallium/drivers/asahi/agx_state.c
@@ -1868,10 +1868,17 @@ agx_shader_initialize(struct agx_device *dev, struct agx_uncompiled_shader *so,
       so->info.cull_distance_size = nir->info.cull_distance_array_size;
    }
 
-   /* Vectorize SSBOs before lowering them, since it is significantly harder to
-    * vectorize the lowered code.
+   /* Shrink and vectorize SSBOs before lowering them, since it is harder to
+    * optimize the lowered code.
     */
+   NIR_PASS(_, nir, nir_lower_alu_to_scalar, NULL, NULL);
    NIR_PASS(_, nir, nir_lower_load_const_to_scalar);
+   NIR_PASS(_, nir, agx_nir_cleanup_amul);
+   NIR_PASS(_, nir, nir_opt_constant_folding);
+   NIR_PASS(_, nir, nir_copy_prop);
+   NIR_PASS(_, nir, nir_opt_cse);
+   NIR_PASS(_, nir, nir_opt_dce);
+   NIR_PASS(_, nir, nir_opt_shrink_vectors, true);
    NIR_PASS(_, nir, nir_copy_prop);
    NIR_PASS(
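For reference, the ulea_agx/ilea_agx opcodes introduced above compute, per the lower_lea rules in agx_nir_algebraic.py, a 64-bit base plus a zero- or sign-extended 32-bit index scaled by a small power-of-two shift. A minimal C sketch of those semantics follows; the emulate_* helpers are illustrative only and not part of the patch.

#include <stdint.h>

/* ulea_agx(a, b, c) lowers to a + (u2u64(b) << c): zero-extend the index. */
static inline uint64_t
emulate_ulea_agx(uint64_t base, uint32_t index, unsigned shift)
{
   return base + ((uint64_t)index << shift);
}

/* ilea_agx(a, b, c) lowers to a + (i2i64(b) << c): sign-extend the index
 * to 64 bits before scaling, so negative offsets subtract from the base.
 */
static inline uint64_t
emulate_ilea_agx(uint64_t base, int32_t index, unsigned shift)
{
   return base + ((uint64_t)(int64_t)index << shift);
}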