nir/opt_uniform_subgroup: use ballot_bit_count

Using bit_count on the result of ballot doesn't work for targets where
ballot's num_components > 1.

Signed-off-by: Job Noorman <jnoorman@igalia.com>
Reviewed-by: Emma Anholt <emma@anholt.net>
Fixes: d2e1e4442a ("ir3: enable nir_opt_uniform_subgroup")
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35669>
This commit is contained in:
Job Noorman 2025-07-03 06:27:09 +02:00 committed by Marge Bot
parent 2091d199db
commit ae66bd1c00
2 changed files with 18 additions and 10 deletions

View file

@ -69,9 +69,18 @@ opt_uniform_subgroup_filter(const nir_instr *instr, const void *_state)
}
}
/* Build the bit count of a ballot value.
 *
 * A ballot with exactly one component is passed straight to nir_bit_count.
 * Any other component count is routed through nir_ballot_bit_count_reduce,
 * using the ballot's own bit size as the result bit size.
 */
static nir_def *
ballot_bit_count(nir_builder *b, nir_def *ballot)
{
   if (ballot->num_components != 1) {
      return nir_ballot_bit_count_reduce(b, ballot->bit_size, ballot);
   }

   return nir_bit_count(b, ballot);
}
static nir_def *
count_active_invocations(nir_builder *b, nir_def *value, bool inclusive,
bool has_mbcnt_amd)
bool has_mbcnt_amd,
const nir_lower_subgroups_options *options)
{
/* For the non-inclusive case, the two paths are functionally the same.
* For the inclusive case, they are similar but very subtly different.
@ -91,11 +100,13 @@ count_active_invocations(nir_builder *b, nir_def *value, bool inclusive,
if (has_mbcnt_amd) {
return nir_mbcnt_amd(b, value, nir_imm_int(b, (int)inclusive));
} else {
nir_def *mask = inclusive
? nir_load_subgroup_le_mask(b, 1, 32)
: nir_load_subgroup_lt_mask(b, 1, 32);
nir_def *mask =
inclusive ? nir_load_subgroup_le_mask(b, options->ballot_components,
options->ballot_bit_size)
: nir_load_subgroup_lt_mask(b, options->ballot_components,
options->ballot_bit_size);
return nir_bit_count(b, nir_iand(b, value, mask));
return ballot_bit_count(b, nir_iand(b, value, mask));
}
}
@ -119,11 +130,11 @@ opt_uniform_subgroup_instr(nir_builder *b, nir_instr *instr, void *_state)
options->ballot_bit_size, nir_imm_true(b));
if (intrin->intrinsic == nir_intrinsic_reduce) {
count = nir_bit_count(b, ballot);
count = ballot_bit_count(b, ballot);
} else {
count = count_active_invocations(b, ballot,
intrin->intrinsic == nir_intrinsic_inclusive_scan,
false);
false, options);
}
const unsigned bit_size = intrin->src[0].ssa->bit_size;

View file

@ -1,8 +1,5 @@
test_shader_instructions,Fail
test_line_rasterization,Fail
# error: src->ssa->num_components == num_components (../src/compiler/nir/nir_validate.c:205)
test_shader_waveop_maximal_convergence,Crash
# Hangs
test_fence_wait_robustness,Crash