diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index 266b2ac26ba..d4f97728d9e 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -245,11 +245,12 @@ typedef struct { * (Other lanes are not deactivated but their calculation is not used.) */ static nir_def * -summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords) +summarize_repack(nir_builder *b, nir_def *packed_counts, bool mask_lane_id, unsigned num_lds_dwords) { /* We'll use shift to filter out the bytes not needed by the current lane. * - * Need to shift by: `num_lds_dwords * 4 - 1 - lane_id` (in bytes) + * For each row: + * Need to shift by: `num_lds_dwords * 4 - 1 - lane_id_in_row` (in bytes) * in order to implement an inclusive scan. * * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes. @@ -262,12 +263,21 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords */ nir_def *lane_id = nir_load_subgroup_invocation(b); + + /* Mask lane ID so that lanes 16...31 also have the ID 0...15, + * in order to perform a second horizontal sum in parallel when needed. + */ + if (mask_lane_id) + lane_id = nir_iand_imm(b, lane_id, 0xf); + nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8); assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8); bool use_dot = b->shader->options->has_udot_4x8; if (num_lds_dwords == 1) { - /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */ + /* Broadcast the packed data we read from LDS + * (to the first 16 lanes of the row, but we only care up to num_waves). + */ nir_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0)); /* Horizontally add the packed bytes. 
*/ @@ -279,7 +289,9 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0)); } } else if (num_lds_dwords == 2) { - /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */ + /* Broadcast the packed data we read from LDS + * (to the first 16 lanes of the row, but we only care up to num_waves). + */ nir_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0)); nir_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0)); @@ -301,46 +313,61 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords /** * Repacks invocations in the current workgroup to eliminate gaps between them. * - * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave). + * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave) for each repack. * Assumes that all invocations in the workgroup are active (exec = -1). */ static void -repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool, - wg_repack_result *results, +repack_invocations_in_workgroup(nir_builder *b, nir_def **input_bool, + wg_repack_result *results, const unsigned num_repacks, nir_def *lds_addr_base, unsigned max_num_waves, unsigned wave_size) { - /* Input boolean: 1 if the current invocation should survive the repack. */ - assert(input_bool->bit_size == 1); + /* We can currently only do up to 2 repacks at a time. */ + assert(num_repacks <= 2); /* STEP 1. Count surviving invocations in the current wave. * * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask. 
*/ - nir_def *input_mask = nir_ballot(b, 1, wave_size, input_bool); - nir_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask); + nir_def *input_mask[2]; + nir_def *surviving_invocations_in_current_wave[2]; + + for (unsigned i = 0; i < num_repacks; ++i) { + /* Input should be boolean: 1 if the current invocation should survive the repack. */ + assert(input_bool[i]->bit_size == 1); + + input_mask[i] = nir_ballot(b, 1, wave_size, input_bool[i]); + surviving_invocations_in_current_wave[i] = nir_bit_count(b, input_mask[i]); + } /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */ if (max_num_waves == 1) { - results[0].num_repacked_invocations = surviving_invocations_in_current_wave; - results[0].repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)); + for (unsigned i = 0; i < num_repacks; ++i) { + results[i].num_repacked_invocations = surviving_invocations_in_current_wave[i]; + results[i].repacked_invocation_index = nir_mbcnt_amd(b, input_mask[i], nir_imm_int(b, 0)); + } return; } /* STEP 2. Waves tell each other their number of surviving invocations. * - * Each wave activates only its first lane (exec = 1), which stores the number of surviving - * invocations in that wave into the LDS, then reads the numbers from every wave. + * Row 0 (lanes 0-15) performs the first repack, and Row 1 (lanes 16-31) the second in parallel. + * Each wave activates only its first lane per row, which stores the number of surviving + * invocations in that wave into the LDS for that repack, then reads the numbers from every wave. * * The workgroup size of NGG shaders is at most 256, which means * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode. + * For each repack: * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary. + * (The maximum is 4 dwords for 2 repacks in Wave32 mode.) 
*/ const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4); assert(num_lds_dwords <= 2); - const unsigned ballot = 1; + + /* The first lane of each row (per repack) needs to access the LDS. */ + const unsigned ballot = num_repacks == 1 ? 1 : 0x10001; nir_def *wave_id = nir_load_subgroup_id(b); nir_def *dont_care = nir_undef(b, 1, num_lds_dwords * 32); @@ -348,7 +375,16 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool, nir_if *if_use_lds = nir_push_if(b, nir_inverse_ballot(b, 1, nir_imm_intN_t(b, ballot, wave_size))); { - nir_def *store_byte = nir_u2u8(b, surviving_invocations_in_current_wave); + nir_def *store_val = surviving_invocations_in_current_wave[0]; + + if (num_repacks == 2) { + nir_def *lane_id_0 = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 1, wave_size)); + nir_def *off = nir_bcsel(b, lane_id_0, nir_imm_int(b, 0), nir_imm_int(b, num_lds_dwords * 4)); + lds_addr_base = nir_iadd_nuw(b, lds_addr_base, off); + store_val = nir_bcsel(b, lane_id_0, store_val, surviving_invocations_in_current_wave[1]); + } + + nir_def *store_byte = nir_u2u8(b, store_val); nir_def *lds_offset = nir_iadd(b, lds_addr_base, wave_id); nir_store_shared(b, store_byte, lds_offset); @@ -366,26 +402,31 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool, * By now, every wave knows the number of surviving invocations in all waves. * Each number is 1 byte, and they are packed into up to 2 dwords. * - * Each lane N will sum the number of surviving invocations inclusively from waves 0 to N. - * If the workgroup has M waves, then each wave will use only its first M lanes for this. + * For each row (of 16 lanes): + * Each lane N (in the row) will sum the number of surviving invocations inclusively from waves 0 to N. + * If the workgroup has M waves, then each row will use only its first M lanes for this. * (Other lanes are not deactivated but their calculation is not used.) 
 * - We read the sum from the lane whose id is the current wave's id, + * - We read the sum from the lane whose id (in the row) is the current wave's id, * and subtract the number of its own surviving invocations. * Add the masked bitcount to this, and we get the repacked invocation index. - * - We read the sum from the lane whose id is the number of waves in the workgroup minus 1. + * - We read the sum from the lane whose id (in the row) is the number of waves in the workgroup minus 1. * This is the total number of surviving invocations in the workgroup. */ nir_def *num_waves = nir_load_num_subgroups(b); - nir_def *sum = summarize_repack(b, packed_counts, num_lds_dwords); + nir_def *sum = summarize_repack(b, packed_counts, num_repacks == 2, num_lds_dwords); - nir_def *wg_repacked_index_base = - nir_isub(b, nir_read_invocation(b, sum, wave_id), surviving_invocations_in_current_wave); - results[0].num_repacked_invocations = - nir_read_invocation(b, sum, nir_iadd_imm(b, num_waves, -1)); - results[0].repacked_invocation_index = - nir_mbcnt_amd(b, input_mask, wg_repacked_index_base); + for (unsigned i = 0; i < num_repacks; ++i) { + nir_def *index_base_lane = nir_iadd_imm_nuw(b, wave_id, i * 16); + nir_def *num_invocations_lane = nir_iadd_imm_nuw(b, num_waves, i * 16 - 1); + nir_def *wg_repacked_index_base = + nir_isub(b, nir_read_invocation(b, sum, index_base_lane), surviving_invocations_in_current_wave[i]); + results[i].num_repacked_invocations = + nir_read_invocation(b, sum, num_invocations_lane); + results[i].repacked_invocation_index = + nir_mbcnt_amd(b, input_mask[i], wg_repacked_index_base); + } } static nir_def * @@ -1610,11 +1651,12 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c nir_def *es_accepted = nir_load_var(b, s->es_accepted_var); /* Repack the vertices that survived the culling. 
*/ - wg_repack_result rep = {0}; - repack_invocations_in_workgroup(b, es_accepted, &rep, lds_scratch_base, + nir_def *accepted[] = { es_accepted }; + wg_repack_result rep[1] = {0}; + repack_invocations_in_workgroup(b, accepted, rep, 1, lds_scratch_base, s->max_num_waves, s->options->wave_size); - nir_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations; - nir_def *es_exporter_tid = rep.repacked_invocation_index; + nir_def *num_live_vertices_in_workgroup = rep[0].num_repacked_invocations; + nir_def *es_exporter_tid = rep[0].repacked_invocation_index; /* If all vertices are culled, set primitive count to 0 as well. */ nir_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b); @@ -3428,7 +3470,7 @@ ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s) * stream.. */ wg_repack_result rep = {0}; - repack_invocations_in_workgroup(b, prim_live[stream], &rep, scratch_base, + repack_invocations_in_workgroup(b, &prim_live[stream], &rep, 1, scratch_base, s->max_num_waves, s->options->wave_size); /* nir_intrinsic_set_vertex_and_primitive_count can also get primitive count of @@ -3533,7 +3575,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s) nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0); wg_repack_result rep = {0}; - repack_invocations_in_workgroup(b, vertex_live, &rep, s->lds_addr_gs_scratch, + repack_invocations_in_workgroup(b, &vertex_live, &rep, 1, s->lds_addr_gs_scratch, s->max_num_waves, s->options->wave_size); nir_def *workgroup_num_vertices = rep.num_repacked_invocations;