nir/subgroups: Add option to lower Boolean subgroup reductions

This will be useful for AMD, and probably Intel as well.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/218>
Connor Abbott 2019-02-04 12:55:32 +01:00 committed by Marge Bot
parent 387e698bde
commit 1dab2c5bd2
2 changed files with 151 additions and 10 deletions

@@ -5597,6 +5597,7 @@ typedef struct nir_lower_subgroups_options {
   bool lower_rotate_to_shuffle : 1;
   bool lower_ballot_bit_count_to_mbcnt_amd : 1;
   bool lower_inverse_ballot : 1;
   bool lower_boolean_reduce : 1;
} nir_lower_subgroups_options;

bool nir_lower_subgroups(nir_shader *shader,

@@ -355,6 +355,152 @@ lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin)
   return nir_load_var(b, result);
}
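
/* Add up the per-component bit counts of a (possibly multi-component) ballot
 * value to get the total number of set bits.
 */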
static nir_def *
vec_bit_count(nir_builder *b, nir_def *value)
{
   nir_def *vec_result = nir_bit_count(b, value);
   nir_def *result = nir_channel(b, vec_result, 0);
   for (unsigned i = 1; i < value->num_components; i++)
      result = nir_iadd(b, result, nir_channel(b, vec_result, i));
   return result;
}

/* Produce a bitmask of 111...000...111... alternating between "size"
 * 1's and "size" 0's (the LSB is 1).
 */
static uint64_t
reduce_mask(unsigned size, unsigned ballot_bit_size)
{
   uint64_t mask = 0;
   for (unsigned i = 0; i < ballot_bit_size; i += 2 * size) {
      mask |= ((1ull << size) - 1) << i;
   }
   return mask;
}
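
For intuition, a few concrete values of this mask, written as hypothetical self-checks (not part of the patch):

   assert(reduce_mask(1, 16) == 0x5555);
   assert(reduce_mask(2, 16) == 0x3333);
   assert(reduce_mask(4, 16) == 0x0f0f);

Each mask keeps only the low "size" bits of every 2 * size-bit group.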

/* Operate on a uniform per-thread bitmask provided by ballot() to perform the
 * desired Boolean reduction. Assumes that the identity of the operation is
 * false (so, no iand).
 */
static nir_def *
lower_boolean_reduce_internal(nir_builder *b, nir_def *src,
                              unsigned cluster_size, nir_op op,
                              const nir_lower_subgroups_options *options)
{
   for (unsigned size = 1; size < cluster_size; size *= 2) {
      nir_def *shifted = nir_ushr_imm(b, src, size);
      src = nir_build_alu2(b, op, shifted, src);
      uint64_t mask = reduce_mask(size, options->ballot_bit_size);
      src = nir_iand_imm(b, src, mask);
      shifted = nir_ishl_imm(b, src, size);
      src = nir_ior(b, src, shifted);
   }
   return src;
}
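
To see what the loop does, here is a hand trace (illustrative, not from the patch) of an ior reduction with cluster_size == 4 on an 8-bit ballot where only lane 2 is true:

   /* src = 0b00000100
    * size == 1: src |= src >> 1   -> 0b00000110
    *            src &= 0x55       -> 0b00000100
    *            src |= src << 1   -> 0b00001100
    * size == 2: src |= src >> 2   -> 0b00001111
    *            src &= 0x33       -> 0b00000011
    *            src |= src << 2   -> 0b00001111
    */

After the loop, all four lanes of the first cluster read 1 and all four lanes of the second read 0, i.e. each lane holds its cluster's reduction result.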

/* Operate on a uniform per-thread bitmask provided by ballot() to perform the
 * desired Boolean inclusive scan. Assumes that the identity of the operation
 * is false (so, no iand).
 */
static nir_def *
lower_boolean_scan_internal(nir_builder *b, nir_def *src,
                            nir_op op,
                            const nir_lower_subgroups_options *options)
{
   if (op == nir_op_ior) {
      /* We want to return a bitmask with all 1's starting at the first 1 in
       * src. -src is equivalent to ~src + 1. While src | ~src returns all
       * 1's, src | (~src + 1) returns all 1's except for the bits changed by
       * the increment: any 1's of ~src below its least significant 0 are
       * turned into 0 by the carry (src is also 0 there, so those bits stay
       * 0 after the OR), and the least significant 0 of ~src is turned into
       * 1 (src already has a 1 there, so the OR is unaffected). So the final
       * output is what we want.
       */
      return nir_ior(b, src, nir_ineg(b, src));
   } else {
      assert(op == nir_op_ixor);
      for (unsigned shift = 1; shift < options->ballot_bit_size; shift *= 2) {
         src = nir_ixor(b, src, nir_ishl_imm(b, src, shift));
      }
      return src;
   }
}
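
Two quick hand traces on an 8-bit ballot (illustrative only): for ior with src = 0b00001000, -src = 0b11111000, so src | -src = 0b11111000 and every lane at or above the first true lane reads 1. For ixor, the loop computes a prefix parity:

   /* src = 0b00000101 (lanes 0 and 2 true)
    * src ^= src << 1  -> 0b00001111
    * src ^= src << 2  -> 0b00110011
    * src ^= src << 4  -> 0b00000011
    */

so lane i ends up holding the xor of lanes 0 through i.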

static nir_def *
lower_boolean_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
                     const nir_lower_subgroups_options *options)
{
   assert(intrin->num_components == 1);
   assert(options->ballot_components == 1);

   unsigned cluster_size =
      intrin->intrinsic == nir_intrinsic_reduce ? nir_intrinsic_cluster_size(intrin) : 0;
   nir_op op = nir_intrinsic_reduction_op(intrin);

   /* For certain cluster sizes, reductions of iand and ior can be implemented
    * more efficiently.
    */
   if (intrin->intrinsic == nir_intrinsic_reduce) {
      if (cluster_size == 0) {
         if (op == nir_op_iand)
            return nir_vote_all(b, 1, intrin->src[0].ssa);
         else if (op == nir_op_ior)
            return nir_vote_any(b, 1, intrin->src[0].ssa);
         else if (op == nir_op_ixor)
            return nir_i2b(b, nir_iand_imm(b, vec_bit_count(b, nir_ballot(b,
                                                               options->ballot_components,
                                                               options->ballot_bit_size,
                                                               intrin->src[0].ssa)),
                                           1));
         else
            unreachable("bad boolean reduction op");
      }

      if (cluster_size == 4) {
         if (op == nir_op_iand)
            return nir_quad_vote_all(b, 1, intrin->src[0].ssa);
         else if (op == nir_op_ior)
            return nir_quad_vote_any(b, 1, intrin->src[0].ssa);
      }
   }
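
   /* General path: materialize the Boolean as a ballot bitmask, reduce or
    * scan it with integer ALU ops, then convert back with inverse_ballot.
    */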
   nir_def *src = intrin->src[0].ssa;

   /* Apply DeMorgan's law to implement "and" reductions, since all the
    * lower_boolean_*_internal() functions assume an identity of 0 to make
    * the generated code shorter.
    */
   nir_op new_op = (op == nir_op_iand) ? nir_op_ior : op;
   if (op == nir_op_iand) {
      src = nir_inot(b, src);
   }

   nir_def *val = nir_ballot(b, options->ballot_components, options->ballot_bit_size, src);

   switch (intrin->intrinsic) {
   case nir_intrinsic_reduce:
      val = lower_boolean_reduce_internal(b, val, cluster_size, new_op, options);
      break;
   case nir_intrinsic_inclusive_scan:
      val = lower_boolean_scan_internal(b, val, new_op, options);
      break;
   case nir_intrinsic_exclusive_scan:
      val = lower_boolean_scan_internal(b, val, new_op, options);
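      /* An exclusive scan is the inclusive scan shifted up by one lane, with
       * the identity (false) shifted into lane 0.
       */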
      val = nir_ishl_imm(b, val, 1);
      break;
   default:
      unreachable("bad intrinsic");
   }

   if (op == nir_op_iand) {
      val = nir_inot(b, val);
   }

   return nir_inverse_ballot(b, 1, val);
}

static bool
lower_subgroups_filter(const nir_instr *instr, const void *_options)
{

@@ -486,16 +632,6 @@ build_subgroup_mask(nir_builder *b,
                      result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
}

static nir_def *
vec_bit_count(nir_builder *b, nir_def *value)
{
   nir_def *vec_result = nir_bit_count(b, value);
   nir_def *result = nir_channel(b, vec_result, 0);
   for (unsigned i = 1; i < value->num_components; i++)
      result = nir_iadd(b, result, nir_channel(b, vec_result, i));
   return result;
}

static nir_def *
vec_find_lsb(nir_builder *b, nir_def *value)
{

@@ -825,12 +961,16 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
         return intrin->src[0].ssa;
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin);
      if (options->lower_boolean_reduce && intrin->def.bit_size == 1)
         return lower_boolean_reduce(b, intrin, options);
      return ret;
   }

   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin);
      if (options->lower_boolean_reduce && intrin->def.bit_size == 1)
         return lower_boolean_reduce(b, intrin, options);
      break;

   case nir_intrinsic_rotate: