nak: Lower scan/reduce in NIR

We can probably do slightly better than this if we take advantage of the
predicate destination in SHFL, but not by much.  All of the insanity is
still required (NVIDIA basically emits this); we just might be able to
save ourselves a few comparison ops.
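
For reference, each scan step currently lowers to a shuffle plus an
explicit lane-index compare and select; the compare is the part a SHFL
predicate destination could in principle absorb.  Roughly, per step
(mirroring build_scan_full below, so b, data, red_op, and i are the
names used in that pass):

    nir_def *idx = nir_load_subgroup_invocation(b);
    nir_def *has_buddy = nir_ige_imm(b, idx, i);
    nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
    nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
    data = nir_bcsel(b, has_buddy, accum, data);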

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26264>
Authored by Faith Ekstrand on 2023-11-17 13:30:08 -06:00, committed by Marge Bot
parent 11bcce9461
commit cca40086c6
4 changed files with 200 additions and 0 deletions

@@ -22,6 +22,7 @@ libnak_c_files = files(
  'nak.h',
  'nak_nir.c',
  'nak_nir_add_barriers.c',
  'nak_nir_lower_scan_reduce.c',
  'nak_nir_lower_tex.c',
  'nak_nir_lower_vtg_io.c',
  'nak_nir_lower_gs_intrinsics.c',

@@ -297,6 +297,7 @@ nak_preprocess_nir(nir_shader *nir, const struct nak_compiler *nak)
      .lower_inverse_ballot = true,
   };
   OPT(nir, nir_lower_subgroups, &subgroups_options);
   OPT(nir, nak_nir_lower_scan_reduce);
}
static uint16_t

@@ -0,0 +1,197 @@
/*
 * Copyright © 2023 Collabora, Ltd.
 * SPDX-License-Identifier: MIT
 */

#include "nak_private.h"
#include "nir_builder.h"

static nir_def *
build_identity(nir_builder *b, nir_op op)
{
   nir_const_value ident_const = nir_alu_binop_identity(op, 32);
   return nir_build_imm(b, 1, 32, &ident_const);
}

/* Implementation of scan/reduce that assumes a full subgroup */
static nir_def *
build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                nir_def *data, unsigned cluster_size)
{
   switch (op) {
   case nir_intrinsic_exclusive_scan:
   case nir_intrinsic_inclusive_scan: {
      for (unsigned i = 1; i < cluster_size; i *= 2) {
         nir_def *idx = nir_load_subgroup_invocation(b);
         nir_def *has_buddy = nir_ige_imm(b, idx, i);

         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
         nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
         data = nir_bcsel(b, has_buddy, accum, data);
      }

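      /* Illustrative example: for an inclusive scan with cluster_size == 4,
       * the loop unrolls to two shuffle-up steps:
       *
       *    i = 1:  lane n (n >= 1) accumulates lane n-1's original value
       *    i = 2:  lane n (n >= 2) accumulates lane n-2's step-1 value
       *
       * after which lane n holds red_op over lanes 0..n.
       */
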
      if (op == nir_intrinsic_exclusive_scan) {
         /* For exclusive scans, we need to shift one more time and fill in
          * the bottom channel with identity.
          */
         assert(cluster_size == 32);

         nir_def *idx = nir_load_subgroup_invocation(b);
         nir_def *has_buddy = nir_ige_imm(b, idx, 1);

         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, 1));
         data = nir_bcsel(b, has_buddy, buddy_data,
                          build_identity(b, red_op));
      }

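      /* Net effect: lane 0 now holds the identity and every other lane n
       * holds the inclusive result of lane n-1.
       */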
      return data;
   }

   case nir_intrinsic_reduce: {
      for (unsigned i = 1; i < cluster_size; i *= 2) {
         nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
         data = nir_build_alu2(b, red_op, data, buddy_data);
      }

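      /* After the xor butterfly (assuming a power-of-two cluster_size),
       * every invocation in the cluster holds the reduction of the entire
       * cluster, so we can return data directly.
       */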
      return data;
   }

   default:
      unreachable("Unsupported scan/reduce op");
   }
}

/* Fully generic implementation of scan/reduce that takes a mask */
static nir_def *
build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                  nir_def *data, nir_def *mask, unsigned max_mask_bits)
{
   nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, 32);

   /* Mask of all channels whose values we still need to accumulate.  Our
    * own value is already in data (the loop below computes an inclusive
    * scan; exclusive is fixed up afterwards), so we only need to consider
    * lower indexed invocations.
    */
   nir_def *remaining = nir_iand(b, mask, lt_mask);

   for (unsigned i = 1; i < max_mask_bits; i *= 2) {
      /* At each step, our buddy channel is the first channel we have yet
       * to take into account in the accumulator.
       */
      nir_def *has_buddy = nir_ine_imm(b, remaining, 0);
      nir_def *buddy = nir_ufind_msb(b, remaining);

      /* Accumulate with our buddy channel, if any */
      nir_def *buddy_data = nir_shuffle(b, data, buddy);
      nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
      data = nir_bcsel(b, has_buddy, accum, data);

      /* We just took into account everything in our buddy's accumulator
       * from the previous step.  The only things remaining are whatever
       * channels were remaining for our buddy.
       */
      nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
      remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
   }

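   /* Worked example (illustrative only): with mask = 0xb (channels 0, 1,
    * and 3 enabled), channel 3 starts with remaining = 0x3.  The first
    * step pulls in channel 1's data and inherits its remaining mask (0x1);
    * the second step pulls in channel 0's data and clears remaining,
    * leaving channel 3 with the reduction of channels 0, 1, and 3, i.e.
    * its inclusive scan result.
    */
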
   switch (op) {
   case nir_intrinsic_exclusive_scan: {
      /* For exclusive scans, we need to shift one more time and fill in
       * the bottom channel with identity.
       *
       * Some of this will get CSE'd with the first step but that's okay.
       * The code is cleaner this way.
       */
      nir_def *lower = nir_iand(b, mask, lt_mask);
      nir_def *has_buddy = nir_ine_imm(b, lower, 0);
      nir_def *buddy = nir_ufind_msb(b, lower);

      nir_def *buddy_data = nir_shuffle(b, data, buddy);
      return nir_bcsel(b, has_buddy, buddy_data, build_identity(b, red_op));
   }

   case nir_intrinsic_inclusive_scan:
      return data;

   case nir_intrinsic_reduce: {
      /* For reductions, we need to take the top value of the scan */
      nir_def *idx = nir_ufind_msb(b, mask);
      return nir_shuffle(b, data, idx);
   }

   default:
      unreachable("Unsupported scan/reduce op");
   }
}

static bool
nak_nir_lower_scan_reduce_intrin(nir_builder *b,
                                 nir_intrinsic_instr *intrin,
                                 UNUSED void *_data)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_exclusive_scan:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_reduce:
      break;
   default:
      return false;
   }

   const nir_op red_op = nir_intrinsic_reduction_op(intrin);

   /* Grab the cluster size, defaulting to 32 */
   unsigned cluster_size = 32;
   if (nir_intrinsic_has_cluster_size(intrin)) {
      cluster_size = nir_intrinsic_cluster_size(intrin);
      if (cluster_size == 0 || cluster_size > 32)
         cluster_size = 32;
   }

   b->cursor = nir_before_instr(&intrin->instr);

   nir_def *data;
   if (cluster_size == 1) {
      /* Simple case where we're not actually doing any reducing at all. */
      assert(intrin->intrinsic == nir_intrinsic_reduce);
      data = intrin->src[0].ssa;
   } else {
      /* First, we need a mask of all invocations to be included in the
       * reduction or scan.  For trivial cluster sizes, that's just the
       * mask of enabled channels.
       */
      nir_def *mask = nir_ballot(b, 1, 32, nir_imm_true(b));

      if (cluster_size < 32) {
         nir_def *idx = nir_load_subgroup_invocation(b);
         nir_def *cluster =
            nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));

         nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size));
         cluster_mask = nir_ishl(b, cluster_mask, cluster);
         mask = nir_iand(b, mask, cluster_mask);
      }

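      /* Illustrative example: with cluster_size == 8, invocation 13
       * computes cluster = 13 & ~7 = 8 and cluster_mask = 0xff << 8 =
       * 0xff00, restricting the reduction to channels 8..15.
       *
       * If the resulting mask covers the entire subgroup, the cheap
       * shuffle-based path suffices; otherwise we need the fully general
       * mask-walking version.
       */
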
      nir_def *full, *partial;
      nir_push_if(b, nir_ieq_imm(b, mask, -1));
      {
         full = build_scan_full(b, intrin->intrinsic, red_op,
                                intrin->src[0].ssa, cluster_size);
      }
      nir_push_else(b, NULL);
      {
         partial = build_scan_reduce(b, intrin->intrinsic, red_op,
                                     intrin->src[0].ssa, mask, cluster_size);
      }
      nir_pop_if(b, NULL);

      data = nir_if_phi(b, full, partial);
   }

   nir_def_rewrite_uses(&intrin->def, data);
   nir_instr_remove(&intrin->instr);

   return true;
}

bool
nak_nir_lower_scan_reduce(nir_shader *nir)
{
   return nir_shader_intrinsics_pass(nir, nak_nir_lower_scan_reduce_intrin,
                                     nir_metadata_none, NULL);
}

@@ -144,6 +144,7 @@ struct nak_nir_tex_flags {
   uint32_t pad:26;
};

bool nak_nir_lower_scan_reduce(nir_shader *shader);
bool nak_nir_lower_tex(nir_shader *nir, const struct nak_compiler *nak);
bool nak_nir_lower_gs_intrinsics(nir_shader *shader);