microsoft/compiler: For emulating scan, ensure all threads are active when reading cross-lane

HLSL docs say WaveReadLaneAt is undefined if the target lane is inactive. This makes
sense, since the target lane may need to *send* the data rather than have it pulled
by the calling lane. So don't early-out of the loop: iterate through the whole wave
on all threads and read the cross-lane data before branching.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27624>
This commit is contained in:
Jesse Natalie 2024-02-14 14:24:55 -08:00 committed by Marge Bot
parent 219be55807
commit 0daad70f9f

View file

@ -2178,6 +2178,7 @@ lower_subgroup_scan(nir_builder *b, nir_intrinsic_instr *intr, void *data)
b->cursor = nir_before_instr(&intr->instr);
nir_op op = nir_intrinsic_reduction_op(intr);
nir_def *subgroup_id = nir_load_subgroup_invocation(b);
nir_def *subgroup_size = nir_load_subgroup_size(b);
nir_def *active_threads = nir_ballot(b, 4, 32, nir_imm_true(b));
nir_def *base_value;
uint32_t bit_size = intr->def.bit_size;
@ -2203,18 +2204,23 @@ lower_subgroup_scan(nir_builder *b, nir_intrinsic_instr *intr, void *data)
nir_store_var(b, result_var, base_value, 1);
nir_loop *loop = nir_push_loop(b);
nir_def *loop_counter = nir_load_var(b, loop_counter_var);
nir_if *nif = nir_push_if(b, intr->intrinsic == nir_intrinsic_inclusive_scan ?
nir_if *nif = nir_push_if(b, nir_ilt(b, loop_counter, subgroup_size));
nir_def *other_thread_val = nir_read_invocation(b, intr->src[0].ssa, loop_counter);
nir_def *thread_in_range = intr->intrinsic == nir_intrinsic_inclusive_scan ?
nir_ige(b, subgroup_id, loop_counter) :
nir_ilt(b, loop_counter, subgroup_id));
nir_if *if_active_thread = nir_push_if(b, nir_ballot_bitfield_extract(b, 32, active_threads, loop_counter));
nir_def *result = nir_build_alu2(b, op,
nir_load_var(b, result_var),
nir_read_invocation(b, intr->src[0].ssa, loop_counter));
nir_ilt(b, loop_counter, subgroup_id);
nir_def *thread_active = nir_ballot_bitfield_extract(b, 1, active_threads, loop_counter);
nir_if *if_active_thread = nir_push_if(b, nir_iand(b, thread_in_range, thread_active));
nir_def *result = nir_build_alu2(b, op, nir_load_var(b, result_var), other_thread_val);
nir_store_var(b, result_var, result, 1);
nir_pop_if(b, if_active_thread);
nir_store_var(b, loop_counter_var, nir_iadd_imm(b, loop_counter, 1), 1);
nir_jump(b, nir_jump_continue);
nir_pop_if(b, nif);
nir_jump(b, nir_jump_break);
nir_pop_loop(b, loop);