diff --git a/src/compiler/nir/nir_opt_uniform_subgroup.c b/src/compiler/nir/nir_opt_uniform_subgroup.c
index b4b39b1e6fb..f907092bc51 100644
--- a/src/compiler/nir/nir_opt_uniform_subgroup.c
+++ b/src/compiler/nir/nir_opt_uniform_subgroup.c
@@ -94,6 +94,14 @@ parse_select_of_con_values(nir_builder *b, nir_def *def, struct select_info *inf
    }
 }
 
+static nir_def *
+get_ballot(nir_builder *b, nir_def *cond,
+           const nir_lower_subgroups_options *options)
+{
+   return nir_ballot(b, options->ballot_components,
+                     options->ballot_bit_size, cond ? cond : nir_imm_true(b));
+}
+
 static nir_def *
 ballot_bit_count(nir_builder *b, nir_def *ballot)
 {
@@ -103,25 +111,34 @@ ballot_bit_count(nir_builder *b, nir_def *ballot)
 }
 
 static nir_def *
-count_active_invocations(nir_builder *b, nir_def *value, bool inclusive,
+count_active_invocations(nir_builder *b, nir_def *cond, bool inclusive,
                          const nir_lower_subgroups_options *options)
 {
+   nir_def *value = get_ballot(b, cond, options);
+
    /* For the non-inclusive case, the two paths are functionally the same.
     * For the inclusive case, they are similar but very subtly different.
     *
     * The bit_count path will mask "value" with the subgroup LE mask instead
     * of the subgroup LT mask. This is the definition of the inclusive count.
     *
-    * AMD's mbcnt instruction always uses the subgroup LT mask. To perform the
-    * inclusive count using mbcnt, two assumptions are made. First, trivially,
-    * the current invocation is active. Second, the bit for the current
-    * invocation in "value" is set. Since "value" is assumed to be the result
-    * of ballot(true), the second condition will also be met.
+    * AMD's mbcnt instruction always uses the subgroup LT mask.
     *
-    * When those conditions are met, the inclusive count is the exclusive
-    * count plus one.
+    * When we know the condition is true, the bit for the current
+    * invocation value[N] is 1. Therefore we can count value[0:N-1] and
+    * only need to add 1 for the inclusive count.
+    *
+    * When we can't make any assumption about the active invocations' bits
+    * because the condition is not known true, transform the inclusive case
+    * to an exclusive count by counting value[1:N] and adding value[0]
+    * in the accumulator.
+    * The additional operations here can use the uniform ALU.
     */
-   if (options->lower_ballot_bit_count_to_mbcnt_amd) {
+   if (options->lower_ballot_bit_count_to_mbcnt_amd && inclusive && cond) {
+      nir_def *first_bit = nir_iand_imm(b, nir_u2u32(b, value), 1);
+      value = nir_ushr_imm(b, value, 1);
+      return nir_mbcnt_amd(b, value, first_bit);
+   } else if (options->lower_ballot_bit_count_to_mbcnt_amd) {
       return nir_mbcnt_amd(b, value, nir_imm_int(b, (int)inclusive));
    } else {
       nir_def *mask =
@@ -134,6 +151,31 @@ count_active_invocations(nir_builder *b, nir_def *value, bool inclusive,
    }
 }
 
+static nir_def *
+conditional_add_xor_reduce(nir_builder *b, nir_intrinsic_instr *intrin, nir_def *cond, nir_def *src,
+                           const nir_lower_subgroups_options *options)
+{
+   const nir_op reduction_op = (nir_op)nir_intrinsic_reduction_op(intrin);
+   nir_def *count;
+
+   if (intrin->intrinsic == nir_intrinsic_reduce) {
+      count = ballot_bit_count(b, get_ballot(b, cond, options));
+   } else {
+      count = count_active_invocations(b, cond,
+                                       intrin->intrinsic == nir_intrinsic_inclusive_scan,
+                                       options);
+   }
+
+   if (reduction_op == nir_op_iadd) {
+      return nir_imul(b, nir_u2uN(b, count, src->bit_size), src);
+   } else if (reduction_op == nir_op_fadd) {
+      return nir_fmul(b, nir_u2fN(b, count, src->bit_size), src);
+   } else {
+      count = nir_iand(b, count, nir_imm_int(b, 1));
+      return nir_imul(b, nir_u2uN(b, count, src->bit_size), src);
+   }
+}
+
 static bool
 opt_uniform_subgroup_instr(nir_builder *b, nir_intrinsic_instr *intrin, void *_state)
 {
@@ -188,38 +230,30 @@ opt_uniform_subgroup_instr(nir_builder *b, nir_intrinsic_instr *intrin, void *_s
       case nir_op_iadd:
       case nir_op_fadd:
       case nir_op_ixor: {
-         if (nir_src_is_divergent(&intrin->src[0]))
-            return false;
-
          if (nir_intrinsic_has_cluster_size(intrin) && nir_intrinsic_cluster_size(intrin))
             return false;
 
-         nir_def *count;
-         nir_def *ballot = nir_ballot(b, options->ballot_components,
-                                      options->ballot_bit_size, nir_imm_true(b));
-
-         if (intrin->intrinsic == nir_intrinsic_reduce) {
-            count = ballot_bit_count(b, ballot);
+         if (!nir_src_is_divergent(&intrin->src[0])) {
+            replacement = conditional_add_xor_reduce(b, intrin, NULL,
+                                                     intrin->src[0].ssa, options);
          } else {
-            count = count_active_invocations(b, ballot,
-                                             intrin->intrinsic == nir_intrinsic_inclusive_scan,
-                                             options);
+            /* Ballot must be scalar. */
+            if (intrin->def.num_components != 1)
+               return false;
+
+            struct select_info sel;
+            if (!parse_select_of_con_values(b, intrin->src[0].ssa, &sel))
+               return false;
+
+            nir_def *parts[2];
+
+            for (unsigned i = 0; i < 2; i++) {
+               nir_def *cond = i ? nir_inot(b, sel.cond) : sel.cond;
+               parts[i] = conditional_add_xor_reduce(b, intrin, cond, sel.values[i], options);
+            }
+
+            replacement = nir_build_alu2(b, reduction_op, parts[0], parts[1]);
         }
 
-         const unsigned bit_size = intrin->src[0].ssa->bit_size;
-
-         if (reduction_op == nir_op_iadd) {
-            replacement = nir_imul(b, nir_u2uN(b, count, bit_size),
-                                   intrin->src[0].ssa);
-         } else if (reduction_op == nir_op_fadd) {
-            replacement = nir_fmul(b, nir_u2fN(b, count, bit_size),
-                                   intrin->src[0].ssa);
-         } else {
-            replacement = nir_imul(b,
-                                   nir_u2uN(b,
-                                            nir_iand(b, count, nir_imm_int(b, 1)),
-                                            bit_size),
-                                   intrin->src[0].ssa);
-         }
          break;
       }
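
Aside (not part of the patch): a minimal, standalone sketch of the mbcnt
rewrite in count_active_invocations() above, checking on plain 64-bit masks
that seeding the accumulator with value[0] and counting value[1:N] under the
LT mask reproduces the inclusive count popcount(value & le_mask). Here
mbcnt_model() is a hypothetical stand-in for AMD's mbcnt
(add + popcount(value & lt_mask)), not a NIR or Mesa API, and the ballot
constant is arbitrary.

#include <assert.h>
#include <stdint.h>

/* Hypothetical model of AMD's mbcnt for lane N:
 * add + popcount(value & lt_mask), where lt_mask covers bits [0:N-1]. */
static unsigned
mbcnt_model(uint64_t value, unsigned add, unsigned lane)
{
   const uint64_t lt_mask = (lane == 0) ? 0 : (~0ull >> (64 - lane));
   return add + (unsigned)__builtin_popcountll(value & lt_mask);
}

int
main(void)
{
   /* An arbitrary ballot; the condition is not known to be true, so
    * bit N may be 0 even for an active lane N. */
   const uint64_t ballot = 0xdeadbeefcafef00dull;

   for (unsigned lane = 0; lane < 64; lane++) {
      /* Definition of the inclusive count: popcount(ballot & le_mask),
       * where le_mask covers bits [0:N]. */
      const uint64_t le_mask =
         (lane == 63) ? ~0ull : ((UINT64_C(2) << lane) - 1);
      const unsigned expected =
         (unsigned)__builtin_popcountll(ballot & le_mask);

      /* The patch's transform: shift so the LT mask counts bits [1:N]
       * of the original ballot, and seed the accumulator with bit 0. */
      const unsigned got = mbcnt_model(ballot >> 1, ballot & 1u, lane);

      assert(got == expected);
   }

   return 0;
}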