diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index b8022870ba7..b8c1ee77f7f 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -5137,6 +5137,7 @@ typedef struct nir_lower_subgroups_options {
    bool lower_elect:1;
    bool lower_read_invocation_to_cond:1;
    bool lower_rotate_to_shuffle:1;
+   bool lower_ballot_bit_count_to_mbcnt_amd:1; /* lower ballot_bit_count_{ex,in}clusive to mbcnt_amd */
 } nir_lower_subgroups_options;
 
 bool nir_lower_subgroups(nir_shader *shader,
diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index 75a8ecd22b8..9ace7f25481 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -774,6 +774,20 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 
    case nir_intrinsic_ballot_bit_count_exclusive:
    case nir_intrinsic_ballot_bit_count_inclusive: {
+      assert(intrin->src[0].is_ssa);
+      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
+                                                 options); /* computed first: used by both the mbcnt path and the mask path */
+      if (options->lower_ballot_bit_count_to_mbcnt_amd) { /* NOTE(review): shift/and below look scalar-only; confirm ballot is one component here */
+         nir_ssa_def *acc;
+         if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_exclusive) {
+            acc = nir_imm_int(b, 0); /* exclusive: accumulator starts at zero, ballot passed unchanged */
+         } else {
+            acc = nir_iand_imm(b, nir_u2u32(b, int_val), 0x1); /* inclusive: accumulator is bit 0 of the ballot */
+            int_val = nir_ushr_imm(b, int_val, 1); /* then shift the ballot right by one so bit i+1 lands at position i */
+         }
+         return nir_mbcnt_amd(b, int_val, acc); /* NOTE(review): presumably acc + number of mask bits below this invocation; confirm */
+      }
+
       nir_ssa_def *mask;
       if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
          mask = nir_inot(b, build_subgroup_gt_mask(b, options));
@@ -781,10 +795,6 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
          mask = nir_inot(b, build_subgroup_ge_mask(b, options));
       }
 
-      assert(intrin->src[0].is_ssa);
-      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
-                                                 options);
-
       return vec_bit_count(b, nir_iand(b, int_val, mask));
    }
 