ac/nir/ngg: Workgroup scan over two bools.

Implement two workgroup scans over two boolean values in parallel,
so that they can be done with very minimal ALU overhead.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32290>
This commit is contained in:
Timur Kristóf 2024-11-04 17:13:27 +01:00
parent 78f77e161c
commit 492d8f3778

View file

@ -245,11 +245,12 @@ typedef struct {
* (Other lanes are not deactivated but their calculation is not used.)
*/
static nir_def *
summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords)
summarize_repack(nir_builder *b, nir_def *packed_counts, bool mask_lane_id, unsigned num_lds_dwords)
{
/* We'll use shift to filter out the bytes not needed by the current lane.
*
* Need to shift by: `num_lds_dwords * 4 - 1 - lane_id` (in bytes)
* For each row:
* Need to shift by: `num_lds_dwords * 4 - 1 - lane_id_in_row` (in bytes)
* in order to implement an inclusive scan.
*
* When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
@ -262,12 +263,21 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
*/
nir_def *lane_id = nir_load_subgroup_invocation(b);
/* Mask lane ID so that lanes 16...31 also have the ID 0...15,
* in order to perform a second horizontal sum in parallel when needed.
*/
if (mask_lane_id)
lane_id = nir_iand_imm(b, lane_id, 0xf);
nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8);
assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8);
bool use_dot = b->shader->options->has_udot_4x8;
if (num_lds_dwords == 1) {
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
/* Broadcast the packed data we read from LDS
* (to the first 16 lanes of the row, but we only care up to num_waves).
*/
nir_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
/* Horizontally add the packed bytes. */
@ -279,7 +289,9 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
}
} else if (num_lds_dwords == 2) {
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
/* Broadcast the packed data we read from LDS
* (to the first 16 lanes of the row, but we only care up to num_waves).
*/
nir_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
nir_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
@ -301,46 +313,61 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
/**
* Repacks invocations in the current workgroup to eliminate gaps between them.
*
* Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
* Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave) for each repack.
* Assumes that all invocations in the workgroup are active (exec = -1).
*/
static void
repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
wg_repack_result *results,
repack_invocations_in_workgroup(nir_builder *b, nir_def **input_bool,
wg_repack_result *results, const unsigned num_repacks,
nir_def *lds_addr_base, unsigned max_num_waves,
unsigned wave_size)
{
/* Input boolean: 1 if the current invocation should survive the repack. */
assert(input_bool->bit_size == 1);
/* We can currently only do up to 2 repacks at a time. */
assert(num_repacks <= 2);
/* STEP 1. Count surviving invocations in the current wave.
*
* Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
*/
nir_def *input_mask = nir_ballot(b, 1, wave_size, input_bool);
nir_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
nir_def *input_mask[2];
nir_def *surviving_invocations_in_current_wave[2];
for (unsigned i = 0; i < num_repacks; ++i) {
/* Input should be boolean: 1 if the current invocation should survive the repack. */
assert(input_bool[i]->bit_size == 1);
input_mask[i] = nir_ballot(b, 1, wave_size, input_bool[i]);
surviving_invocations_in_current_wave[i] = nir_bit_count(b, input_mask[i]);
}
/* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
if (max_num_waves == 1) {
results[0].num_repacked_invocations = surviving_invocations_in_current_wave;
results[0].repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0));
for (unsigned i = 0; i < num_repacks; ++i) {
results[i].num_repacked_invocations = surviving_invocations_in_current_wave[i];
results[i].repacked_invocation_index = nir_mbcnt_amd(b, input_mask[i], nir_imm_int(b, 0));
}
return;
}
/* STEP 2. Waves tell each other their number of surviving invocations.
*
* Each wave activates only its first lane (exec = 1), which stores the number of surviving
* invocations in that wave into the LDS, then reads the numbers from every wave.
* Row 0 (lanes 0-15) performs the first repack, and Row 1 (lanes 16-31) the second in parallel.
* Each wave activates only its first lane per row, which stores the number of surviving
* invocations in that wave into the LDS for that repack, then reads the numbers from every wave.
*
* The workgroup size of NGG shaders is at most 256, which means
* the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
* For each repack:
* Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
* (The maximum is 4 dwords for 2 repacks in Wave32 mode.)
*/
const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
assert(num_lds_dwords <= 2);
const unsigned ballot = 1;
/* The first lane of each row (per repack) needs to access the LDS. */
const unsigned ballot = num_repacks == 1 ? 1 : 0x10001;
nir_def *wave_id = nir_load_subgroup_id(b);
nir_def *dont_care = nir_undef(b, 1, num_lds_dwords * 32);
@ -348,7 +375,16 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
nir_if *if_use_lds = nir_push_if(b, nir_inverse_ballot(b, 1, nir_imm_intN_t(b, ballot, wave_size)));
{
nir_def *store_byte = nir_u2u8(b, surviving_invocations_in_current_wave);
nir_def *store_val = surviving_invocations_in_current_wave[0];
if (num_repacks == 2) {
nir_def *lane_id_0 = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 1, wave_size));
nir_def *off = nir_bcsel(b, lane_id_0, nir_imm_int(b, 0), nir_imm_int(b, num_lds_dwords * 4));
lds_addr_base = nir_iadd_nuw(b, lds_addr_base, off);
store_val = nir_bcsel(b, lane_id_0, store_val, surviving_invocations_in_current_wave[1]);
}
nir_def *store_byte = nir_u2u8(b, store_val);
nir_def *lds_offset = nir_iadd(b, lds_addr_base, wave_id);
nir_store_shared(b, store_byte, lds_offset);
@ -366,26 +402,31 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
* By now, every wave knows the number of surviving invocations in all waves.
* Each number is 1 byte, and they are packed into up to 2 dwords.
*
* Each lane N will sum the number of surviving invocations inclusively from waves 0 to N.
* If the workgroup has M waves, then each wave will use only its first M lanes for this.
* For each row (of 16 lanes):
* Each lane N (in the row) will sum the number of surviving invocations inclusively from waves 0 to N.
* If the workgroup has M waves, then each row will use only its first M lanes for this.
* (Other lanes are not deactivated but their calculation is not used.)
*
* - We read the sum from the lane whose id is the current wave's id,
* - We read the sum from the lane whose id (in the row) is the current wave's id,
* and subtract the number of its own surviving invocations.
* Add the masked bitcount to this, and we get the repacked invocation index.
* - We read the sum from the lane whose id is the number of waves in the workgroup minus 1.
* - We read the sum from the lane whose id (in the row) is the number of waves in the workgroup minus 1.
* This is the total number of surviving invocations in the workgroup.
*/
nir_def *num_waves = nir_load_num_subgroups(b);
nir_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
nir_def *sum = summarize_repack(b, packed_counts, num_repacks == 2, num_lds_dwords);
nir_def *wg_repacked_index_base =
nir_isub(b, nir_read_invocation(b, sum, wave_id), surviving_invocations_in_current_wave);
results[0].num_repacked_invocations =
nir_read_invocation(b, sum, nir_iadd_imm(b, num_waves, -1));
results[0].repacked_invocation_index =
nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
for (unsigned i = 0; i < num_repacks; ++i) {
nir_def *index_base_lane = nir_iadd_imm_nuw(b, wave_id, i * 16);
nir_def *num_invocartions_lane = nir_iadd_imm_nuw(b, num_waves, i * 16 - 1);
nir_def *wg_repacked_index_base =
nir_isub(b, nir_read_invocation(b, sum, index_base_lane), surviving_invocations_in_current_wave[i]);
results[i].num_repacked_invocations =
nir_read_invocation(b, sum, num_invocartions_lane);
results[i].repacked_invocation_index =
nir_mbcnt_amd(b, input_mask[i], wg_repacked_index_base);
}
}
static nir_def *
@ -1610,11 +1651,12 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
nir_def *es_accepted = nir_load_var(b, s->es_accepted_var);
/* Repack the vertices that survived the culling. */
wg_repack_result rep = {0};
repack_invocations_in_workgroup(b, es_accepted, &rep, lds_scratch_base,
nir_def *accepted[] = { es_accepted };
wg_repack_result rep[1] = {0};
repack_invocations_in_workgroup(b, accepted, rep, 1, lds_scratch_base,
s->max_num_waves, s->options->wave_size);
nir_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
nir_def *es_exporter_tid = rep.repacked_invocation_index;
nir_def *num_live_vertices_in_workgroup = rep[0].num_repacked_invocations;
nir_def *es_exporter_tid = rep[0].repacked_invocation_index;
/* If all vertices are culled, set primitive count to 0 as well. */
nir_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b);
@ -3428,7 +3470,7 @@ ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s)
* stream.
*/
wg_repack_result rep = {0};
repack_invocations_in_workgroup(b, prim_live[stream], &rep, scratch_base,
repack_invocations_in_workgroup(b, &prim_live[stream], &rep, 1, scratch_base,
s->max_num_waves, s->options->wave_size);
/* nir_intrinsic_set_vertex_and_primitive_count can also get primitive count of
@ -3533,7 +3575,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0);
wg_repack_result rep = {0};
repack_invocations_in_workgroup(b, vertex_live, &rep, s->lds_addr_gs_scratch,
repack_invocations_in_workgroup(b, &vertex_live, &rep, 1, s->lds_addr_gs_scratch,
s->max_num_waves, s->options->wave_size);
nir_def *workgroup_num_vertices = rep.num_repacked_invocations;