mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-08 15:38:09 +02:00
ac/nir/ngg: Workgroup scan over two bools.
Implement two workgroup scans over two boolean values in parallel, so that they can be done with very minimal ALU overhead. Signed-off-by: Timur Kristóf <timur.kristof@gmail.com> Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32290>
This commit is contained in:
parent
78f77e161c
commit
492d8f3778
1 changed files with 76 additions and 34 deletions
|
|
@ -245,11 +245,12 @@ typedef struct {
|
|||
* (Other lanes are not deactivated but their calculation is not used.)
|
||||
*/
|
||||
static nir_def *
|
||||
summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords)
|
||||
summarize_repack(nir_builder *b, nir_def *packed_counts, bool mask_lane_id, unsigned num_lds_dwords)
|
||||
{
|
||||
/* We'll use shift to filter out the bytes not needed by the current lane.
|
||||
*
|
||||
* Need to shift by: `num_lds_dwords * 4 - 1 - lane_id` (in bytes)
|
||||
* For each row:
|
||||
* Need to shift by: `num_lds_dwords * 4 - 1 - lane_id_in_row` (in bytes)
|
||||
* in order to implement an inclusive scan.
|
||||
*
|
||||
* When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
|
||||
|
|
@ -262,12 +263,21 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
|
|||
*/
|
||||
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
|
||||
/* Mask lane ID so that lanes 16...31 also have the ID 0...15,
|
||||
* in order to perform a second horizontal sum in parallel when needed.
|
||||
*/
|
||||
if (mask_lane_id)
|
||||
lane_id = nir_iand_imm(b, lane_id, 0xf);
|
||||
|
||||
nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8);
|
||||
assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8);
|
||||
bool use_dot = b->shader->options->has_udot_4x8;
|
||||
|
||||
if (num_lds_dwords == 1) {
|
||||
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
|
||||
/* Broadcast the packed data we read from LDS
|
||||
* (to the first 16 lanes of the row, but we only care up to num_waves).
|
||||
*/
|
||||
nir_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
|
||||
/* Horizontally add the packed bytes. */
|
||||
|
|
@ -279,7 +289,9 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
|
|||
return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
}
|
||||
} else if (num_lds_dwords == 2) {
|
||||
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
|
||||
/* Broadcast the packed data we read from LDS
|
||||
* (to the first 16 lanes of the row, but we only care up to num_waves).
|
||||
*/
|
||||
nir_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
nir_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
|
||||
|
|
@ -301,46 +313,61 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
|
|||
/**
|
||||
* Repacks invocations in the current workgroup to eliminate gaps between them.
|
||||
*
|
||||
* Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
|
||||
* Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave) for each repack.
|
||||
* Assumes that all invocations in the workgroup are active (exec = -1).
|
||||
*/
|
||||
static void
|
||||
repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
|
||||
wg_repack_result *results,
|
||||
repack_invocations_in_workgroup(nir_builder *b, nir_def **input_bool,
|
||||
wg_repack_result *results, const unsigned num_repacks,
|
||||
nir_def *lds_addr_base, unsigned max_num_waves,
|
||||
unsigned wave_size)
|
||||
{
|
||||
/* Input boolean: 1 if the current invocation should survive the repack. */
|
||||
assert(input_bool->bit_size == 1);
|
||||
/* We can currently only do up to 2 repacks at a time. */
|
||||
assert(num_repacks <= 2);
|
||||
|
||||
/* STEP 1. Count surviving invocations in the current wave.
|
||||
*
|
||||
* Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
|
||||
*/
|
||||
|
||||
nir_def *input_mask = nir_ballot(b, 1, wave_size, input_bool);
|
||||
nir_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
|
||||
nir_def *input_mask[2];
|
||||
nir_def *surviving_invocations_in_current_wave[2];
|
||||
|
||||
for (unsigned i = 0; i < num_repacks; ++i) {
|
||||
/* Input should be boolean: 1 if the current invocation should survive the repack. */
|
||||
assert(input_bool[i]->bit_size == 1);
|
||||
|
||||
input_mask[i] = nir_ballot(b, 1, wave_size, input_bool[i]);
|
||||
surviving_invocations_in_current_wave[i] = nir_bit_count(b, input_mask[i]);
|
||||
}
|
||||
|
||||
/* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
|
||||
if (max_num_waves == 1) {
|
||||
results[0].num_repacked_invocations = surviving_invocations_in_current_wave;
|
||||
results[0].repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0));
|
||||
for (unsigned i = 0; i < num_repacks; ++i) {
|
||||
results[i].num_repacked_invocations = surviving_invocations_in_current_wave[i];
|
||||
results[i].repacked_invocation_index = nir_mbcnt_amd(b, input_mask[i], nir_imm_int(b, 0));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/* STEP 2. Waves tell each other their number of surviving invocations.
|
||||
*
|
||||
* Each wave activates only its first lane (exec = 1), which stores the number of surviving
|
||||
* invocations in that wave into the LDS, then reads the numbers from every wave.
|
||||
* Row 0 (lanes 0-15) performs the first repack, and Row 1 (lanes 16-31) the second in parallel.
|
||||
* Each wave activates only its first lane per row, which stores the number of surviving
|
||||
* invocations in that wave into the LDS for that repack, then reads the numbers from every wave.
|
||||
*
|
||||
* The workgroup size of NGG shaders is at most 256, which means
|
||||
* the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
|
||||
* For each repack:
|
||||
* Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
|
||||
* (The maximum is 4 dwords for 2 repacks in Wave32 mode.)
|
||||
*/
|
||||
|
||||
const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
|
||||
assert(num_lds_dwords <= 2);
|
||||
const unsigned ballot = 1;
|
||||
|
||||
/* The first lane of each row (per repack) needs to access the LDS. */
|
||||
const unsigned ballot = num_repacks == 1 ? 1 : 0x10001;
|
||||
|
||||
nir_def *wave_id = nir_load_subgroup_id(b);
|
||||
nir_def *dont_care = nir_undef(b, 1, num_lds_dwords * 32);
|
||||
|
|
@ -348,7 +375,16 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
|
|||
|
||||
nir_if *if_use_lds = nir_push_if(b, nir_inverse_ballot(b, 1, nir_imm_intN_t(b, ballot, wave_size)));
|
||||
{
|
||||
nir_def *store_byte = nir_u2u8(b, surviving_invocations_in_current_wave);
|
||||
nir_def *store_val = surviving_invocations_in_current_wave[0];
|
||||
|
||||
if (num_repacks == 2) {
|
||||
nir_def *lane_id_0 = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 1, wave_size));
|
||||
nir_def *off = nir_bcsel(b, lane_id_0, nir_imm_int(b, 0), nir_imm_int(b, num_lds_dwords * 4));
|
||||
lds_addr_base = nir_iadd_nuw(b, lds_addr_base, off);
|
||||
store_val = nir_bcsel(b, lane_id_0, store_val, surviving_invocations_in_current_wave[1]);
|
||||
}
|
||||
|
||||
nir_def *store_byte = nir_u2u8(b, store_val);
|
||||
nir_def *lds_offset = nir_iadd(b, lds_addr_base, wave_id);
|
||||
nir_store_shared(b, store_byte, lds_offset);
|
||||
|
||||
|
|
@ -366,26 +402,31 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
|
|||
* By now, every wave knows the number of surviving invocations in all waves.
|
||||
* Each number is 1 byte, and they are packed into up to 2 dwords.
|
||||
*
|
||||
* Each lane N will sum the number of surviving invocations inclusively from waves 0 to N.
|
||||
* If the workgroup has M waves, then each wave will use only its first M lanes for this.
|
||||
* For each row (of 16 lanes):
|
||||
* Each lane N (in the row) will sum the number of surviving invocations inclusively from waves 0 to N.
|
||||
* If the workgroup has M waves, then each row will use only its first M lanes for this.
|
||||
* (Other lanes are not deactivated but their calculation is not used.)
|
||||
*
|
||||
* - We read the sum from the lane whose id is the current wave's id,
|
||||
* - We read the sum from the lane whose id (in the row) is the current wave's id,
|
||||
* and subtract the number of its own surviving invocations.
|
||||
* Add the masked bitcount to this, and we get the repacked invocation index.
|
||||
* - We read the sum from the lane whose id is the number of waves in the workgroup minus 1.
|
||||
* - We read the sum from the lane whose id (in the row) is the number of waves in the workgroup minus 1.
|
||||
* This is the total number of surviving invocations in the workgroup.
|
||||
*/
|
||||
|
||||
nir_def *num_waves = nir_load_num_subgroups(b);
|
||||
nir_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
|
||||
nir_def *sum = summarize_repack(b, packed_counts, num_repacks == 2, num_lds_dwords);
|
||||
|
||||
nir_def *wg_repacked_index_base =
|
||||
nir_isub(b, nir_read_invocation(b, sum, wave_id), surviving_invocations_in_current_wave);
|
||||
results[0].num_repacked_invocations =
|
||||
nir_read_invocation(b, sum, nir_iadd_imm(b, num_waves, -1));
|
||||
results[0].repacked_invocation_index =
|
||||
nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
|
||||
for (unsigned i = 0; i < num_repacks; ++i) {
|
||||
nir_def *index_base_lane = nir_iadd_imm_nuw(b, wave_id, i * 16);
|
||||
nir_def *num_invocartions_lane = nir_iadd_imm_nuw(b, num_waves, i * 16 - 1);
|
||||
nir_def *wg_repacked_index_base =
|
||||
nir_isub(b, nir_read_invocation(b, sum, index_base_lane), surviving_invocations_in_current_wave[i]);
|
||||
results[i].num_repacked_invocations =
|
||||
nir_read_invocation(b, sum, num_invocartions_lane);
|
||||
results[i].repacked_invocation_index =
|
||||
nir_mbcnt_amd(b, input_mask[i], wg_repacked_index_base);
|
||||
}
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
|
|
@ -1610,11 +1651,12 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
|
|||
nir_def *es_accepted = nir_load_var(b, s->es_accepted_var);
|
||||
|
||||
/* Repack the vertices that survived the culling. */
|
||||
wg_repack_result rep = {0};
|
||||
repack_invocations_in_workgroup(b, es_accepted, &rep, lds_scratch_base,
|
||||
nir_def *accepted[] = { es_accepted };
|
||||
wg_repack_result rep[1] = {0};
|
||||
repack_invocations_in_workgroup(b, accepted, rep, 1, lds_scratch_base,
|
||||
s->max_num_waves, s->options->wave_size);
|
||||
nir_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
|
||||
nir_def *es_exporter_tid = rep.repacked_invocation_index;
|
||||
nir_def *num_live_vertices_in_workgroup = rep[0].num_repacked_invocations;
|
||||
nir_def *es_exporter_tid = rep[0].repacked_invocation_index;
|
||||
|
||||
/* If all vertices are culled, set primitive count to 0 as well. */
|
||||
nir_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b);
|
||||
|
|
@ -3428,7 +3470,7 @@ ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s)
|
|||
* stream..
|
||||
*/
|
||||
wg_repack_result rep = {0};
|
||||
repack_invocations_in_workgroup(b, prim_live[stream], &rep, scratch_base,
|
||||
repack_invocations_in_workgroup(b, &prim_live[stream], &rep, 1, scratch_base,
|
||||
s->max_num_waves, s->options->wave_size);
|
||||
|
||||
/* nir_intrinsic_set_vertex_and_primitive_count can also get primitive count of
|
||||
|
|
@ -3533,7 +3575,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
|
|||
nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0);
|
||||
wg_repack_result rep = {0};
|
||||
|
||||
repack_invocations_in_workgroup(b, vertex_live, &rep, s->lds_addr_gs_scratch,
|
||||
repack_invocations_in_workgroup(b, &vertex_live, &rep, 1, s->lds_addr_gs_scratch,
|
||||
s->max_num_waves, s->options->wave_size);
|
||||
|
||||
nir_def *workgroup_num_vertices = rep.num_repacked_invocations;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue