nir/opt_intrinsic: optimize quad vote

Optimizes a quadAll()/quadAny() pattern created by dxil-spirv:
7adc87d4de

dxil-spirv can't use clustered reductions because they are not guaranteed
to include helper invocations.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23621>
This commit is contained in:
Rhys Perry 2023-06-13 14:07:53 +01:00 committed by Marge Bot
parent 58f8e0e2a0
commit 8649bde78f
2 changed files with 135 additions and 3 deletions

View file

@ -3658,6 +3658,12 @@ typedef struct nir_shader_compiler_options {
*/
bool optimize_sample_mask_in;
/**
* Optimize boolean reductions of quad broadcasts. This should only be enabled if
* nir_intrinsic_reduce supports INCLUDE_HELPERS.
*/
bool optimize_quad_vote_to_reduce;
bool lower_cs_local_index_to_id;
bool lower_cs_local_id_to_index;

View file

@ -92,9 +92,131 @@ try_opt_bcsel_of_shuffle(nir_builder *b, nir_alu_instr *alu,
return shuffle;
}
/**
 * Tests whether \p src is produced, inside \p block, by one of the quad
 * data-exchange intrinsics accepted here: quad_broadcast (only when its
 * second source is constant), quad_swap_horizontal/vertical/diagonal, or
 * quad_swizzle_amd.
 *
 * \param block   block the producing intrinsic has to live in
 * \param src     source to inspect
 * \param intrin  receives the producing intrinsic on success; it is not
 *                written when the function returns false
 * \return true if \p src matches, false otherwise
 */
static bool
src_is_quad_broadcast(nir_block *block, nir_src src, nir_intrinsic_instr **intrin)
{
   nir_intrinsic_instr *candidate = nir_src_as_intrinsic(src);

   if (!candidate)
      return false;
   if (candidate->instr.block != block)
      return false;

   bool accepted;
   switch (candidate->intrinsic) {
   case nir_intrinsic_quad_broadcast:
      /* Only a constant second source is accepted for quad_broadcast. */
      accepted = nir_src_is_const(candidate->src[1]);
      break;
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
   case nir_intrinsic_quad_swizzle_amd:
      accepted = true;
      break;
   default:
      accepted = false;
      break;
   }

   if (accepted)
      *intrin = candidate;

   return accepted;
}
/**
 * Tests whether \p src is produced by an ALU instruction with opcode \p op
 * whose sources 0 and 1 are both trivial SSA sources.
 *
 * \param op    opcode the producing ALU instruction must have
 * \param src   source to inspect
 * \param srcs  on success, receives the producer's sources 0 and 1; left
 *              untouched on failure
 * \return true if \p src matches, false otherwise
 */
static bool
src_is_alu(nir_op op, nir_src src, nir_src srcs[2])
{
   nir_alu_instr *producer = nir_src_as_alu_instr(src);

   if (!producer)
      return false;
   if (producer->op != op)
      return false;

   for (unsigned i = 0; i < 2; i++) {
      if (!nir_alu_src_is_trivial_ssa(producer, i))
         return false;
   }

   for (unsigned i = 0; i < 2; i++)
      srcs[i] = producer->src[i].src;

   return true;
}
/**
 * Tries to replace a tree of three ALU instructions sharing \p alu's opcode,
 * whose four leaves are quad broadcasts/swaps/swizzles of one value, with a
 * single reduction over clusters of four invocations.
 *
 * Two tree shapes are recognised:
 *   (r0 op r1) op (r2 op r3)
 *   ((r2 op r3) op r1) op r0      (operands of each node in either order)
 *
 * \param b                  builder used to emit the reduction
 * \param alu                root of the candidate tree; its opcode becomes
 *                           the reduction op
 * \param block_has_discard  when true, no replacement is attempted
 * \return the reduction, or NULL when the pattern does not apply
 */
static nir_ssa_def *
try_opt_quad_vote(nir_builder *b, nir_alu_instr *alu, bool block_has_discard)
{
   if (block_has_discard)
      return NULL;

   for (unsigned s = 0; s < 2; s++) {
      if (!nir_alu_src_is_trivial_ssa(alu, s))
         return NULL;
   }

   nir_block *block = alu->instr.block;
   const nir_op op = alu->op;
   nir_src *lhs = &alu->src[0].src;
   nir_src *rhs = &alu->src[1].src;

   nir_intrinsic_instr *reads[4];
   nir_src operands[2][2];

   /* Balanced shape: (r0 op r1) op (r2 op r3). */
   bool matched = src_is_alu(op, *lhs, operands[0]) &&
                  src_is_alu(op, *rhs, operands[1]) &&
                  src_is_quad_broadcast(block, operands[0][0], &reads[0]) &&
                  src_is_quad_broadcast(block, operands[0][1], &reads[1]) &&
                  src_is_quad_broadcast(block, operands[1][0], &reads[2]) &&
                  src_is_quad_broadcast(block, operands[1][1], &reads[3]);

   /* Chained shape: ((r2 op r3) op r1) op r0. The matchers write their
    * out-parameters as they go, so the evaluation order below matters.
    */
   if (!matched) {
      /* Root: one operand is an inner ALU, the other is r0. */
      const bool outer =
         (src_is_alu(op, *lhs, operands[0]) &&
          src_is_quad_broadcast(block, *rhs, &reads[0])) ||
         (src_is_alu(op, *rhs, operands[0]) &&
          src_is_quad_broadcast(block, *lhs, &reads[0]));

      /* Middle node: one operand is the innermost ALU, the other is r1. */
      const bool middle =
         outer &&
         ((src_is_alu(op, operands[0][0], operands[1]) &&
           src_is_quad_broadcast(block, operands[0][1], &reads[1])) ||
          (src_is_alu(op, operands[0][1], operands[1]) &&
           src_is_quad_broadcast(block, operands[0][0], &reads[1])));

      /* Innermost node: r2 op r3. */
      matched = middle &&
                src_is_quad_broadcast(block, operands[1][0], &reads[2]) &&
                src_is_quad_broadcast(block, operands[1][1], &reads[3]);
   }

   if (!matched)
      return NULL;

   /* All four reads must take the same value, and together they must let
    * every invocation of the quad see all four lanes. Bit (4 * invoc + lane)
    * of the mask records that invocation "invoc" reads lane "lane".
    */
   nir_src *value = &reads[0]->src[0];
   uint16_t coverage = 0;

   for (unsigned i = 0; i < 4; i++) {
      nir_intrinsic_instr *read = reads[i];

      if (!nir_srcs_equal(read->src[0], *value))
         return NULL;

      for (unsigned invoc = 0; invoc < 4; invoc++) {
         unsigned lane;

         switch (read->intrinsic) {
         case nir_intrinsic_quad_broadcast:
            lane = nir_src_as_uint(read->src[1]) & 0x3;
            break;
         case nir_intrinsic_quad_swap_horizontal:
            lane = invoc ^ 1;
            break;
         case nir_intrinsic_quad_swap_vertical:
            lane = invoc ^ 2;
            break;
         case nir_intrinsic_quad_swap_diagonal:
            lane = 3 - invoc;
            break;
         case nir_intrinsic_quad_swizzle_amd:
            /* The swizzle mask is read as two bits per invocation. */
            lane = (nir_intrinsic_swizzle_mask(read) >> (invoc * 2)) & 0x3;
            break;
         default:
            unreachable();
         }

         coverage |= 1 << (invoc * 4 + lane);
      }
   }

   if (coverage != 0xffff)
      return NULL;

   /* Emit the replacement reduction. */
   return nir_reduce(b, value->ssa, .reduction_op = op, .cluster_size = 4,
                     .include_helpers = true);
}
static bool
opt_intrinsics_alu(nir_builder *b, nir_alu_instr *alu,
bool block_has_discard)
bool block_has_discard, const struct nir_shader_compiler_options *options)
{
nir_ssa_def *replacement = NULL;
@ -102,7 +224,11 @@ opt_intrinsics_alu(nir_builder *b, nir_alu_instr *alu,
case nir_op_bcsel:
replacement = try_opt_bcsel_of_shuffle(b, alu, block_has_discard);
break;
case nir_op_iand:
case nir_op_ior:
if (nir_dest_bit_size(alu->dest.dest) == 1 && options->optimize_quad_vote_to_reduce)
replacement = try_opt_quad_vote(b, alu, block_has_discard);
break;
default:
break;
}
@ -181,7 +307,7 @@ opt_intrinsics_impl(nir_function_impl *impl,
switch (instr->type) {
case nir_instr_type_alu:
if (opt_intrinsics_alu(&b, nir_instr_as_alu(instr),
block_has_discard))
block_has_discard, options))
progress = true;
break;