ac/nir/ngg: Pass wg_repack_result as pointer instead of returning it.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32290>
This commit is contained in:
Timur Kristóf 2024-11-04 11:28:58 +01:00
parent ac78692be4
commit 78f77e161c

View file

@ -304,8 +304,9 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
* Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
* Assumes that all invocations in the workgroup are active (exec = -1).
*/
static wg_repack_result
static void
repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
wg_repack_result *results,
nir_def *lds_addr_base, unsigned max_num_waves,
unsigned wave_size)
{
@ -322,11 +323,9 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
/* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
if (max_num_waves == 1) {
wg_repack_result r = {
.num_repacked_invocations = surviving_invocations_in_current_wave,
.repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
};
return r;
results[0].num_repacked_invocations = surviving_invocations_in_current_wave;
results[0].repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0));
return;
}
/* STEP 2. Waves tell each other their number of surviving invocations.
@ -383,16 +382,10 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
nir_def *wg_repacked_index_base =
nir_isub(b, nir_read_invocation(b, sum, wave_id), surviving_invocations_in_current_wave);
nir_def *wg_num_repacked_invocations =
results[0].num_repacked_invocations =
nir_read_invocation(b, sum, nir_iadd_imm(b, num_waves, -1));
nir_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
wg_repack_result r = {
.num_repacked_invocations = wg_num_repacked_invocations,
.repacked_invocation_index = wg_repacked_index,
};
return r;
results[0].repacked_invocation_index =
nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
}
static nir_def *
@ -1617,9 +1610,9 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
nir_def *es_accepted = nir_load_var(b, s->es_accepted_var);
/* Repack the vertices that survived the culling. */
wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, lds_scratch_base,
s->max_num_waves,
s->options->wave_size);
wg_repack_result rep = {0};
repack_invocations_in_workgroup(b, es_accepted, &rep, lds_scratch_base,
s->max_num_waves, s->options->wave_size);
nir_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
nir_def *es_exporter_tid = rep.repacked_invocation_index;
@ -3434,9 +3427,9 @@ ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s)
* LDS at once, then we only need one barrier instead of one for each
* stream.
*/
wg_repack_result rep =
repack_invocations_in_workgroup(b, prim_live[stream], scratch_base,
s->max_num_waves, s->options->wave_size);
wg_repack_result rep = {0};
repack_invocations_in_workgroup(b, prim_live[stream], &rep, scratch_base,
s->max_num_waves, s->options->wave_size);
/* nir_intrinsic_set_vertex_and_primitive_count can also get the primitive count of the
* current wave, but we still need LDS to sum all waves' counts to get the workgroup count.
@ -3538,8 +3531,10 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
* To ensure this, we need to repack invocations that have a live vertex.
*/
nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0);
wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch,
s->max_num_waves, s->options->wave_size);
wg_repack_result rep = {0};
repack_invocations_in_workgroup(b, vertex_live, &rep, s->lds_addr_gs_scratch,
s->max_num_waves, s->options->wave_size);
nir_def *workgroup_num_vertices = rep.num_repacked_invocations;
nir_def *exporter_tid_in_tg = rep.repacked_invocation_index;