diff --git a/src/nouveau/compiler/meson.build b/src/nouveau/compiler/meson.build
index 2a77a6e921c..a6c95c5be6d 100644
--- a/src/nouveau/compiler/meson.build
+++ b/src/nouveau/compiler/meson.build
@@ -22,6 +22,7 @@ libnak_c_files = files(
   'nak.h',
   'nak_nir.c',
   'nak_nir_add_barriers.c',
+  'nak_nir_lower_scan_reduce.c',
   'nak_nir_lower_tex.c',
   'nak_nir_lower_vtg_io.c',
   'nak_nir_lower_gs_intrinsics.c',
diff --git a/src/nouveau/compiler/nak_nir.c b/src/nouveau/compiler/nak_nir.c
index a6639eabd7b..ef748adfdb1 100644
--- a/src/nouveau/compiler/nak_nir.c
+++ b/src/nouveau/compiler/nak_nir.c
@@ -297,6 +297,7 @@ nak_preprocess_nir(nir_shader *nir, const struct nak_compiler *nak)
       .lower_inverse_ballot = true,
    };
    OPT(nir, nir_lower_subgroups, &subgroups_options);
+   OPT(nir, nak_nir_lower_scan_reduce);
 }
 
 static uint16_t
diff --git a/src/nouveau/compiler/nak_nir_lower_scan_reduce.c b/src/nouveau/compiler/nak_nir_lower_scan_reduce.c
new file mode 100644
index 00000000000..2d5923190d6
--- /dev/null
+++ b/src/nouveau/compiler/nak_nir_lower_scan_reduce.c
@@ -0,0 +1,197 @@
+/*
+ * Copyright © 2023 Collabora, Ltd.
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "nak_private.h"
+#include "nir_builder.h"
+
+static nir_def *
+build_identity(nir_builder *b, nir_op op)
+{
+   nir_const_value ident_const = nir_alu_binop_identity(op, 32);
+   return nir_build_imm(b, 1, 32, &ident_const);
+}
+
+/* Implementation of scan/reduce that assumes a full subgroup */
+static nir_def *
+build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
+                nir_def *data, unsigned cluster_size)
+{
+   switch (op) {
+   case nir_intrinsic_exclusive_scan:
+   case nir_intrinsic_inclusive_scan: {
+      for (unsigned i = 1; i < cluster_size; i *= 2) {
+         nir_def *idx = nir_load_subgroup_invocation(b);
+         nir_def *has_buddy = nir_ige_imm(b, idx, i);
+
+         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
+         nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
+         data = nir_bcsel(b, has_buddy, accum, data);
+      }
+
+      if (op == nir_intrinsic_exclusive_scan) {
+         /* For exclusive scans, we need to shift one more time and fill in the
+          * bottom channel with identity.
+          */
+         assert(cluster_size == 32);
+         nir_def *idx = nir_load_subgroup_invocation(b);
+         nir_def *has_buddy = nir_ige_imm(b, idx, 1);
+
+         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, 1));
+         data = nir_bcsel(b, has_buddy, buddy_data, build_identity(b, red_op));
+      }
+
+      return data;
+   }
+
+   case nir_intrinsic_reduce: {
+      for (unsigned i = 1; i < cluster_size; i *= 2) {
+         nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
+         data = nir_build_alu2(b, red_op, data, buddy_data);
+      }
+      return data;
+   }
+
+   default:
+      unreachable("Unsupported scan/reduce op");
+   }
+}
+
+/* Fully generic implementation of scan/reduce that takes a mask */
+static nir_def *
+build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
+                  nir_def *data, nir_def *mask, unsigned max_mask_bits)
+{
+   nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, 32);
+
+   /* Mask of all channels whose values we need to accumulate. Our own value
+    * is already in accum, if inclusive, thanks to the initialization above.
+    * We only need to consider lower indexed invocations.
+    */
+   nir_def *remaining = nir_iand(b, mask, lt_mask);
+
+   for (unsigned i = 1; i < max_mask_bits; i *= 2) {
+      /* At each step, our buddy channel is the first channel we have yet to
+       * take into account in the accumulator.
+       */
+      nir_def *has_buddy = nir_ine_imm(b, remaining, 0);
+      nir_def *buddy = nir_ufind_msb(b, remaining);
+
+      /* Accumulate with our buddy channel, if any */
+      nir_def *buddy_data = nir_shuffle(b, data, buddy);
+      nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
+      data = nir_bcsel(b, has_buddy, accum, data);
+
+      /* We just took into account everything in our buddy's accumulator from
+       * the previous step. The only things remaining are whatever channels
+       * were remaining for our buddy.
+       */
+      nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
+      remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
+   }
+
+   switch (op) {
+   case nir_intrinsic_exclusive_scan: {
+      /* For exclusive scans, we need to shift one more time and fill in the
+       * bottom channel with identity.
+       *
+       * Some of this will get CSE'd with the first step but that's okay. The
+       * code is cleaner this way.
+       */
+      nir_def *lower = nir_iand(b, mask, lt_mask);
+      nir_def *has_buddy = nir_ine_imm(b, lower, 0);
+      nir_def *buddy = nir_ufind_msb(b, lower);
+
+      nir_def *buddy_data = nir_shuffle(b, data, buddy);
+      return nir_bcsel(b, has_buddy, buddy_data, build_identity(b, red_op));
+   }
+
+   case nir_intrinsic_inclusive_scan:
+      return data;
+
+   case nir_intrinsic_reduce: {
+      /* For reductions, we need to take the top value of the scan */
+      nir_def *idx = nir_ufind_msb(b, mask);
+      return nir_shuffle(b, data, idx);
+   }
+
+   default:
+      unreachable("Unsupported scan/reduce op");
+   }
+}
+
+static bool
+nak_nir_lower_scan_reduce_intrin(nir_builder *b,
+                                 nir_intrinsic_instr *intrin,
+                                 UNUSED void *_data)
+{
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_exclusive_scan:
+   case nir_intrinsic_inclusive_scan:
+   case nir_intrinsic_reduce:
+      break;
+   default:
+      return false;
+   }
+
+   const nir_op red_op = nir_intrinsic_reduction_op(intrin);
+
+   /* Grab the cluster size, defaulting to 32 */
+   unsigned cluster_size = 32;
+   if (nir_intrinsic_has_cluster_size(intrin)) {
+      cluster_size = nir_intrinsic_cluster_size(intrin);
+      if (cluster_size == 0 || cluster_size > 32)
+         cluster_size = 32;
+   }
+
+   b->cursor = nir_before_instr(&intrin->instr);
+
+   nir_def *data;
+   if (cluster_size == 1) {
+      /* Simple case where we're not actually doing any reducing at all. */
+      assert(intrin->intrinsic == nir_intrinsic_reduce);
+      data = intrin->src[0].ssa;
+   } else {
+      /* First, we need a mask of all invocations to be included in the
+       * reduction or scan. For trivial cluster sizes, that's just the mask
+       * of enabled channels.
+       */
+      nir_def *mask = nir_ballot(b, 1, 32, nir_imm_true(b));
+      if (cluster_size < 32) {
+         nir_def *idx = nir_load_subgroup_invocation(b);
+         nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));
+
+         nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size));
+         cluster_mask = nir_ishl(b, cluster_mask, cluster);
+
+         mask = nir_iand(b, mask, cluster_mask);
+      }
+
+      nir_def *full, *partial;
+      nir_push_if(b, nir_ieq_imm(b, mask, -1));
+      {
+         full = build_scan_full(b, intrin->intrinsic, red_op,
+                                intrin->src[0].ssa, cluster_size);
+      }
+      nir_push_else(b, NULL);
+      {
+         partial = build_scan_reduce(b, intrin->intrinsic, red_op,
+                                     intrin->src[0].ssa, mask, cluster_size);
+      }
+      nir_pop_if(b, NULL);
+      data = nir_if_phi(b, full, partial);
+   }
+
+   nir_def_rewrite_uses(&intrin->def, data);
+   nir_instr_remove(&intrin->instr);
+
+   return true;
+}
+
+bool
+nak_nir_lower_scan_reduce(nir_shader *nir)
+{
+   return nir_shader_intrinsics_pass(nir, nak_nir_lower_scan_reduce_intrin,
+                                     nir_metadata_none, NULL);
+}
diff --git a/src/nouveau/compiler/nak_private.h b/src/nouveau/compiler/nak_private.h
index b81b2eaef81..cd07cc42559 100644
--- a/src/nouveau/compiler/nak_private.h
+++ b/src/nouveau/compiler/nak_private.h
@@ -144,6 +144,7 @@ struct nak_nir_tex_flags {
    uint32_t pad:26;
 };
 
+bool nak_nir_lower_scan_reduce(nir_shader *shader);
 bool nak_nir_lower_tex(nir_shader *nir, const struct nak_compiler *nak);
 bool nak_nir_lower_gs_intrinsics(nir_shader *shader);
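
Note (not part of the patch): build_scan_full() above emits a Hillis-Steele style scan: at step i every lane folds in the accumulator of the lane i positions below it via nir_shuffle_up, guarded by has_buddy, while reductions use a shuffle_xor butterfly instead. The host-side C sketch below only models that data flow for an inclusive scan with iadd as the reduction op on a full 32-lane subgroup; the lane loop, the LANES constant, and the inclusive_scan_iadd() helper are illustration-only assumptions, not NIR and not anything the pass defines.

/* Host-side model (NOT NIR, not part of this patch) of build_scan_full()
 * for an inclusive scan with red_op = iadd on a full 32-lane subgroup.
 * The lane loop stands in for SIMT execution; shuffle_up plus the bcsel
 * become an array read guarded by has_buddy. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define LANES 32u

static void
inclusive_scan_iadd(const uint32_t in[LANES], uint32_t out[LANES])
{
   uint32_t data[LANES];
   memcpy(data, in, sizeof(data));

   /* log2(LANES) steps: at step i, each lane folds in the accumulator of
    * the lane i below it (nir_shuffle_up + nir_bcsel in the pass). */
   for (unsigned i = 1; i < LANES; i *= 2) {
      uint32_t next[LANES];
      for (unsigned lane = 0; lane < LANES; lane++) {
         const bool has_buddy = lane >= i;                  /* nir_ige_imm */
         const uint32_t buddy_data = has_buddy ? data[lane - i] : 0;
         next[lane] = has_buddy ? data[lane] + buddy_data : data[lane];
      }
      memcpy(data, next, sizeof(data));
   }

   memcpy(out, data, sizeof(data));
}

int
main(void)
{
   uint32_t in[LANES], out[LANES];
   for (unsigned lane = 0; lane < LANES; lane++)
      in[lane] = 1;

   inclusive_scan_iadd(in, out);

   /* With all-ones input, lane N of an inclusive iadd scan holds N + 1. */
   for (unsigned lane = 0; lane < LANES; lane++)
      printf("lane %2u: %u\n", lane, (unsigned)out[lane]);

   return 0;
}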
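
Note (not part of the patch): build_scan_reduce() handles partially-populated masks by pointer jumping. Each lane's buddy is ufind_msb(remaining), i.e. the nearest lower active lane it has not folded in yet; after folding in that buddy's accumulator the lane inherits the buddy's remaining mask, so after k steps an accumulator covers up to 2^k active lanes and log2(32) = 5 steps suffice. The sketch below models that walk on the host under the same illustration-only assumptions (32 lanes, iadd, a hypothetical masked_inclusive_scan_iadd() helper) and uses GCC/Clang's __builtin_clz to stand in for nir_ufind_msb.

/* Host-side model (NOT NIR, not part of this patch) of build_scan_reduce()
 * for an inclusive iadd scan over an arbitrary active-lane mask.  The
 * "remaining" walk mirrors the pass: buddy = ufind_msb(remaining), fold in
 * the buddy's accumulator, then inherit the buddy's remaining mask. */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define LANES 32u

static void
masked_inclusive_scan_iadd(uint32_t mask, const uint32_t in[LANES],
                           uint32_t out[LANES])
{
   uint32_t data[LANES], remaining[LANES];

   for (unsigned lane = 0; lane < LANES; lane++) {
      data[lane] = in[lane];
      /* Active lanes strictly below us: mask & lt_mask in the pass. */
      remaining[lane] = mask & ((1u << lane) - 1u);
   }

   /* Pointer jumping: after step k each accumulator covers up to 2^k
    * active lanes, so log2(32) = 5 steps are enough. */
   for (unsigned i = 1; i < LANES; i *= 2) {
      uint32_t next_data[LANES], next_rem[LANES];
      for (unsigned lane = 0; lane < LANES; lane++) {
         if (remaining[lane] != 0) {
            /* Highest set bit, standing in for nir_ufind_msb */
            const unsigned buddy =
               31u - (unsigned)__builtin_clz(remaining[lane]);
            next_data[lane] = data[lane] + data[buddy];
            next_rem[lane] = remaining[buddy];
         } else {
            next_data[lane] = data[lane];
            next_rem[lane] = 0;
         }
      }
      memcpy(data, next_data, sizeof(data));
      memcpy(remaining, next_rem, sizeof(remaining));
   }

   memcpy(out, data, sizeof(data));
}

int
main(void)
{
   const uint32_t mask = 0x00ff00f0u;   /* arbitrary set of active lanes */
   uint32_t in[LANES], out[LANES];
   for (unsigned lane = 0; lane < LANES; lane++)
      in[lane] = 1;

   masked_inclusive_scan_iadd(mask, in, out);

   /* Each active lane ends with the count of active lanes at or below it. */
   for (unsigned lane = 0; lane < LANES; lane++) {
      if (mask & (1u << lane))
         printf("lane %2u: %u\n", lane, (unsigned)out[lane]);
   }
   return 0;
}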