diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index e818f396d0e..99b2526b618 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -6390,7 +6390,8 @@ bool nir_lower_undef_to_zero(nir_shader *shader);
 
 bool nir_opt_uniform_atomics(nir_shader *shader);
 
-bool nir_opt_uniform_subgroup(nir_shader *shader);
+bool nir_opt_uniform_subgroup(nir_shader *shader,
+                              const nir_lower_subgroups_options *);
 
 bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
                        void *data);
diff --git a/src/compiler/nir/nir_opt_uniform_subgroup.c b/src/compiler/nir/nir_opt_uniform_subgroup.c
index 7739080fcc7..6018b1d0156 100644
--- a/src/compiler/nir/nir_opt_uniform_subgroup.c
+++ b/src/compiler/nir/nir_opt_uniform_subgroup.c
@@ -32,6 +32,7 @@ opt_uniform_subgroup_filter(const nir_instr *instr, const void *_state)
       return !nir_src_is_divergent(intrin->src[0]);
 
    case nir_intrinsic_reduce:
+   case nir_intrinsic_exclusive_scan:
    case nir_intrinsic_inclusive_scan: {
       if (nir_src_is_divergent(intrin->src[0]))
          return false;
@@ -39,6 +40,11 @@ opt_uniform_subgroup_filter(const nir_instr *instr, const void *_state)
       const nir_op reduction_op = (nir_op) nir_intrinsic_reduction_op(intrin);
 
       switch (reduction_op) {
+      case nir_op_iadd:
+      case nir_op_fadd:
+      case nir_op_ixor:
+         return true;
+
       case nir_op_imin:
       case nir_op_umin:
       case nir_op_fmin:
@@ -47,9 +53,8 @@ opt_uniform_subgroup_filter(const nir_instr *instr, const void *_state)
       case nir_op_fmax:
       case nir_op_iand:
       case nir_op_ior:
-         return true;
+         return intrin->intrinsic != nir_intrinsic_exclusive_scan;
 
-      /* FINISHME: iadd, ixor, and fadd are also possible. */
       default:
          return false;
       }
@@ -60,21 +65,94 @@ opt_uniform_subgroup_filter(const nir_instr *instr, const void *_state)
    }
 }
 
+static nir_def *
+count_active_invocations(nir_builder *b, nir_def *value, bool inclusive,
+                         bool has_mbcnt_amd)
+{
+   /* For the non-inclusive case, the two paths are functionally the same.
+    * For the inclusive case, they are similar but very subtly different.
+    *
+    * The bit_count path will mask "value" with the subgroup LE mask instead
+    * of the subgroup LT mask. This is the definition of the inclusive count.
+    *
+    * AMD's mbcnt instruction always uses the subgroup LT mask. To perform the
+    * inclusive count using mbcnt, two assumptions are made. First, trivially,
+    * the current invocation is active. Second, the bit for the current
+    * invocation in "value" is set. Since "value" is assumed to be the result
+    * of ballot(true), the second condition will also be met.
+    *
+    * When those conditions are met, the inclusive count is the exclusive
+    * count plus one.
+    */
+   if (has_mbcnt_amd) {
+      return nir_mbcnt_amd(b, value, nir_imm_int(b, (int) inclusive));
+   } else {
+      nir_def *mask = inclusive
+         ? nir_load_subgroup_le_mask(b, 1, 32)
+         : nir_load_subgroup_lt_mask(b, 1, 32);
+
+      return nir_bit_count(b, nir_iand(b, value, mask));
+   }
+}
+
 static nir_def *
 opt_uniform_subgroup_instr(nir_builder *b, nir_instr *instr, void *_state)
 {
+   const nir_lower_subgroups_options *options = (nir_lower_subgroups_options *) _state;
    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
 
+   if (intrin->intrinsic == nir_intrinsic_reduce ||
+       intrin->intrinsic == nir_intrinsic_inclusive_scan ||
+       intrin->intrinsic == nir_intrinsic_exclusive_scan) {
+      const nir_op reduction_op = (nir_op) nir_intrinsic_reduction_op(intrin);
+
+      if (reduction_op == nir_op_iadd ||
+          reduction_op == nir_op_fadd ||
+          reduction_op == nir_op_ixor) {
+         nir_def *count;
+
+         nir_def *ballot = nir_ballot(b, options->ballot_components,
+                                      options->ballot_bit_size, nir_imm_true(b));
+
+         if (intrin->intrinsic == nir_intrinsic_reduce) {
+            count = nir_bit_count(b, ballot);
+         } else {
+            count = count_active_invocations(b, ballot,
+                                             intrin->intrinsic == nir_intrinsic_inclusive_scan,
+                                             false);
+         }
+
+         const unsigned bit_size = intrin->src[0].ssa->bit_size;
+
+         if (reduction_op == nir_op_iadd) {
+            return nir_imul(b,
+                            nir_u2uN(b, count, bit_size),
+                            intrin->src[0].ssa);
+         } else if (reduction_op == nir_op_fadd) {
+            return nir_fmul(b,
+                            nir_u2fN(b, count, bit_size),
+                            intrin->src[0].ssa);
+         } else {
+            return nir_imul(b,
+                            nir_u2uN(b,
+                                     nir_iand(b, count, nir_imm_int(b, 1)),
+                                     bit_size),
+                            intrin->src[0].ssa);
+         }
+      }
+   }
+
    return intrin->src[0].ssa;
 }
 
 bool
-nir_opt_uniform_subgroup(nir_shader *shader)
+nir_opt_uniform_subgroup(nir_shader *shader,
+                         const nir_lower_subgroups_options *options)
 {
    bool progress = nir_shader_lower_instructions(shader,
                                                  opt_uniform_subgroup_filter,
                                                  opt_uniform_subgroup_instr,
-                                                 NULL);
+                                                 (void *) options);
 
    return progress;
 }
diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c
index 144e540bb40..e195612045b 100644
--- a/src/intel/compiler/brw_nir.c
+++ b/src/intel/compiler/brw_nir.c
@@ -1699,12 +1699,14 @@ brw_postprocess_nir(nir_shader *nir, const struct brw_compiler *compiler,
    NIR_PASS(_, nir, nir_convert_to_lcssa, true, true);
    NIR_PASS_V(nir, nir_divergence_analysis);
 
+   static const nir_lower_subgroups_options subgroups_options = {
+      .ballot_bit_size = 32,
+      .ballot_components = 1,
+      .lower_elect = true,
+      .lower_subgroup_masks = true,
+   };
+
    if (OPT(nir_opt_uniform_atomics)) {
-      const nir_lower_subgroups_options subgroups_options = {
-         .ballot_bit_size = 32,
-         .ballot_components = 1,
-         .lower_elect = true,
-      };
       OPT(nir_lower_subgroups, &subgroups_options);
 
       if (OPT(nir_lower_int64))
@@ -1716,12 +1718,13 @@ brw_postprocess_nir(nir_shader *nir, const struct brw_compiler *compiler,
 
    /* nir_opt_uniform_subgroup can create some operations (e.g.,
     * load_subgroup_lt_mask) that need to be lowered again.
    */
-   if (OPT(nir_opt_uniform_subgroup)) {
-      const nir_lower_subgroups_options subgroups_options = {
-         .ballot_bit_size = 32,
-         .ballot_components = 1,
-         .lower_subgroup_masks = true,
-      };
+   if (OPT(nir_opt_uniform_subgroup, &subgroups_options)) {
+      /* Some of the optimizations can generate 64-bit integer multiplication
+       * that must be lowered.
+       */
+      if (OPT(nir_lower_int64))
+         brw_nir_optimize(nir, is_scalar, devinfo);
+
       OPT(nir_lower_subgroups, &subgroups_options);
    }
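Note (not part of the patch): the new code in opt_uniform_subgroup_instr relies on the identity that, for a subgroup-uniform source x, an iadd or fadd reduction equals count * x and an ixor reduction equals (count & 1) * x, where count is the number of active invocations contributing to the result (the whole ballot for a reduction, the LT- or LE-masked ballot for an exclusive or inclusive scan). The standalone C sketch below brute-forces that identity; the 32-lane subgroup, the active-lane ballot value, and the uniform value are arbitrary assumptions for illustration, and bit_count()/count_active() are hypothetical helpers, not Mesa functions.

/* Standalone sketch: verify the uniform reduction/scan identities used by
 * nir_opt_uniform_subgroup for iadd and ixor (fadd follows the iadd case
 * with the count converted to float).
 */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

#define SUBGROUP_SIZE 32u

/* Number of set bits in v. */
static unsigned
bit_count(uint32_t v)
{
   unsigned n = 0;
   for (; v != 0; v &= v - 1)
      n++;
   return n;
}

/* Active invocations strictly below "lane" (exclusive) or at-or-below it
 * (inclusive), i.e. the ballot masked with the subgroup LT or LE mask.
 */
static unsigned
count_active(uint32_t ballot, unsigned lane, int inclusive)
{
   const uint32_t lt_mask = (lane == 0) ? 0 : (0xffffffffu >> (32 - lane));
   const uint32_t le_mask = lt_mask | (1u << lane);

   return bit_count(ballot & (inclusive ? le_mask : lt_mask));
}

int
main(void)
{
   const uint32_t ballot = 0xdeadbeefu; /* arbitrary set of active lanes */
   const int32_t x = 7;                 /* subgroup-uniform source value */

   /* Reduction: adding x once per active lane gives bit_count(ballot) * x,
    * and XOR-ing x once per active lane cancels in pairs, leaving
    * (bit_count(ballot) & 1) * x.
    */
   int32_t iadd = 0, ixor = 0;
   for (unsigned lane = 0; lane < SUBGROUP_SIZE; lane++) {
      if (ballot & (1u << lane)) {
         iadd += x;
         ixor ^= x;
      }
   }
   assert(iadd == (int32_t) bit_count(ballot) * x);
   assert(ixor == (int32_t) (bit_count(ballot) & 1) * x);

   /* Scans: for an active lane, the inclusive (or exclusive) iadd scan of a
    * uniform x is x times the number of active lanes at-or-below (or strictly
    * below) it, which is what count_active_invocations computes in the pass.
    */
   for (unsigned lane = 0; lane < SUBGROUP_SIZE; lane++) {
      if (!(ballot & (1u << lane)))
         continue;

      int32_t excl = 0, incl = 0;
      for (unsigned j = 0; j <= lane; j++) {
         if (ballot & (1u << j)) {
            if (j < lane)
               excl += x;
            incl += x;
         }
      }
      assert(excl == (int32_t) count_active(ballot, lane, 0) * x);
      assert(incl == (int32_t) count_active(ballot, lane, 1) * x);
   }

   printf("uniform-subgroup reduction/scan identities hold\n");
   return 0;
}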