diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index 0181d1a701d..fffbdac49ef 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -68,103 +68,128 @@ typedef struct } lower_ngg_gs_state; typedef struct { - nir_ssa_def *reduction_result; - nir_ssa_def *excl_scan_result; -} wg_scan_result; + nir_ssa_def *num_repacked_invocations; + nir_ssa_def *repacked_invocation_index; +} wg_repack_result; -static wg_scan_result -workgroup_reduce_and_exclusive_scan(nir_builder *b, nir_ssa_def *input_bool, - unsigned lds_addr_base, unsigned max_num_waves) +/** + * Repacks invocations in the current workgroup to eliminate gaps between them. + * + * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave). + * Assumes that all invocations in the workgroup are active (exec = -1). + */ +static wg_repack_result +repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool, + unsigned lds_addr_base, unsigned max_num_waves) { - /* This performs a reduction along with an exclusive scan addition accross the workgroup. - * Assumes that all lanes are enabled (exec = -1) where this is emitted. - * - * Input: (1) divergent bool - * -- 1 if the lane has a live/valid vertex, 0 otherwise - * Output: (1) result of a reduction over the entire workgroup, - * -- the total number of vertices emitted by the workgroup - * (2) result of an exclusive scan over the entire workgroup - * -- used for vertex compaction, in order to determine - * which lane should export the current lane's vertex - */ - + /* Input boolean: 1 if the current invocation should survive the repack. */ assert(input_bool->bit_size == 1); - /* Reduce the boolean -- result is the number of live vertices in the current wave */ - nir_ssa_def *input_mask = nir_build_ballot(b, 1, 64, input_bool); - nir_ssa_def *wave_reduction = nir_bit_count(b, input_mask); + /* STEP 1. Count surviving invocations in the current wave. + * + * Implemented by a scalar instruction that simply counts the number of bits set in a 64-bit mask. + */ - /* Take care of when we know in compile time that the maximum workgroup size is small */ + nir_ssa_def *input_mask = nir_build_ballot(b, 1, 64, input_bool); + nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask); + + /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */ if (max_num_waves == 1) { - wg_scan_result r = { - .reduction_result = wave_reduction, - .excl_scan_result = nir_build_mbcnt_amd(b, input_mask), + wg_repack_result r = { + .num_repacked_invocations = surviving_invocations_in_current_wave, + .repacked_invocation_index = nir_build_mbcnt_amd(b, input_mask), }; return r; } - /* Number of LDS dwords written by all waves (if there is only 1, that is already handled above) */ - unsigned num_lds_dwords = max_num_waves; - assert(num_lds_dwords >= 2 && num_lds_dwords <= 8); + /* STEP 2. Waves tell each other their number of surviving invocations. + * + * Each wave activates only its first lane (exec = 1), which stores the number of surviving + * invocations in that wave into the LDS, then reads the numbers from every wave. + * + * The workgroup size of NGG shaders is at most 256, which means + * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode. + * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary. + */ - /* NIR doesn't have vec6 and vec7 so just use 8 for these cases. */ - if (num_lds_dwords == 6 || num_lds_dwords == 7) - num_lds_dwords = 8; + const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4); + assert(num_lds_dwords <= 2); nir_ssa_def *wave_id = nir_build_load_subgroup_id(b); - nir_ssa_def *dont_care = nir_ssa_undef(b, num_lds_dwords, 32); + nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32); nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1)); - /* The first lane of each wave stores the result of its subgroup reduction to LDS (NGG scratch). */ - nir_ssa_def *wave_id_lds_addr = nir_imul_imm(b, wave_id, 4u); - nir_build_store_shared(b, wave_reduction, wave_id_lds_addr, .base = lds_addr_base, .align_mul = 4u, .write_mask = 0x1u); + nir_build_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base, .align_mul = 1u, .write_mask = 0x1u); - /* Workgroup barrier: wait for all waves to finish storing their result */ nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared); - /* Only the first lane of each wave loads every wave's results from LDS, to avoid bank conflicts */ - nir_ssa_def *reduction_vector = nir_build_load_shared(b, num_lds_dwords, 32, nir_imm_zero(b, 1, 32), .base = lds_addr_base, .align_mul = 16u); + nir_ssa_def *packed_counts = nir_build_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u); + nir_pop_if(b, if_first_lane); - reduction_vector = nir_if_phi(b, reduction_vector, dont_care); + packed_counts = nir_if_phi(b, packed_counts, dont_care); - nir_ssa_def *reduction_per_wave[8] = {0}; - for (unsigned i = 0; i < num_lds_dwords; ++i) { - nir_ssa_def *reduction_wave_i = nir_channel(b, reduction_vector, i); - reduction_per_wave[i] = nir_build_read_first_invocation(b, reduction_wave_i); - } + /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations. + * + * By now, every wave knows the number of surviving invocations in all waves. + * Each number is 1 byte, and they are packed into up to 2 dwords. + * + * Each lane N will sum the number of surviving invocations from waves 0 to N-1. + * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this. + * (Other lanes are not deactivated but their calculation is not used.) + * + * - We read the sum from the lane whose id is the current wave's id. + * Add the masked bitcount to this, and we get the repacked invocation index. + * - We read the sum from the lane whose id is the number of waves in the workgroup. + * This is the total number of surviving invocations in the workgroup. + */ nir_ssa_def *num_waves = nir_build_load_num_subgroups(b); - nir_ssa_def *wg_reduction = reduction_per_wave[0]; - nir_ssa_def *wg_excl_scan_base = NULL; - for (unsigned i = 0; i < num_lds_dwords; ++i) { - /* Workgroup reduction: - * Add the reduction results from all waves up to and including wave_count. - */ - if (i != 0) { - nir_ssa_def *should_add = nir_ige(b, num_waves, nir_imm_int(b, i + 1u)); - nir_ssa_def *addition = nir_bcsel(b, should_add, reduction_per_wave[i], nir_imm_zero(b, 1, 32)); - wg_reduction = nir_iadd_nuw(b, wg_reduction, addition); - } + /* sel = 0x01010101 * lane_id + 0x03020100 */ + nir_ssa_def *lane_id = nir_load_subgroup_invocation(b); + nir_ssa_def *packed_id = nir_build_byte_permute_amd(b, nir_imm_int(b, 0), lane_id, nir_imm_int(b, 0)); + nir_ssa_def *sel = nir_iadd_imm_nuw(b, packed_id, 0x03020100); + nir_ssa_def *sum = NULL; - /* Base of workgroup exclusive scan: - * Add the reduction results from waves up to and excluding wave_id_in_tg. - */ - if (i != (num_lds_dwords - 1u)) { - nir_ssa_def *should_add = nir_ige(b, wave_id, nir_imm_int(b, i + 1u)); - nir_ssa_def *addition = nir_bcsel(b, should_add, reduction_per_wave[i], nir_imm_zero(b, 1, 32)); - wg_excl_scan_base = !wg_excl_scan_base ? addition : nir_iadd_nuw(b, wg_excl_scan_base, addition); - } + if (num_lds_dwords == 1) { + /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */ + nir_ssa_def *packed_dw = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0)); + + /* Use byte-permute to filter out the bytes not needed by the current lane. */ + nir_ssa_def *filtered_packed = nir_build_byte_permute_amd(b, packed_dw, nir_imm_int(b, 0), sel); + + /* Horizontally add the packed bytes. */ + sum = nir_sad_u8x4(b, filtered_packed, nir_imm_int(b, 0), nir_imm_int(b, 0)); + } else if (num_lds_dwords == 2) { + /* Create selectors for the byte-permutes below. */ + nir_ssa_def *dw0_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x44443210), nir_imm_int(b, 0x4)); + nir_ssa_def *dw1_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x32100000), nir_imm_int(b, 0x4)); + + /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */ + nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0)); + nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0)); + + /* Use byte-permute to filter out the bytes not needed by the current lane. */ + nir_ssa_def *filtered_packed_dw0 = nir_build_byte_permute_amd(b, packed_dw0, nir_imm_int(b, 0), dw0_selector); + nir_ssa_def *filtered_packed_dw1 = nir_build_byte_permute_amd(b, packed_dw1, nir_imm_int(b, 0), dw1_selector); + + /* Horizontally add the packed bytes. */ + sum = nir_sad_u8x4(b, filtered_packed_dw0, nir_imm_int(b, 0), nir_imm_int(b, 0)); + sum = nir_sad_u8x4(b, filtered_packed_dw1, nir_imm_int(b, 0), sum); + } else { + unreachable("Unimplemented NGG wave count"); } - nir_ssa_def *sg_excl_scan = nir_build_mbcnt_amd(b, input_mask); - nir_ssa_def *wg_excl_scan = nir_iadd_nuw(b, wg_excl_scan_base, sg_excl_scan); + nir_ssa_def *wave_repacked_index = nir_build_mbcnt_amd(b, input_mask); + nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id); + nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves); + nir_ssa_def *wg_repacked_index = nir_iadd_nuw(b, wg_repacked_index_base, wave_repacked_index); - wg_scan_result r = { - .reduction_result = wg_reduction, - .excl_scan_result = wg_excl_scan, + wg_repack_result r = { + .num_repacked_invocations = wg_num_repacked_invocations, + .repacked_invocation_index = wg_repacked_index, }; return r; @@ -789,17 +814,16 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s) return; } - /* When the output is not known in compile time: there are gaps between the output vertices data in LDS. - * However, we need to make sure that the vertex exports are packed, meaning that there shouldn't be any gaps - * between the threads that perform the exports. We solve this using a perform a workgroup reduction + scan. + /* When the output vertex count is not known at compile time: + * There may be gaps between invocations that have live vertices, but NGG hardware + * requires that the invocations that export vertices are packed (ie. compact). + * To ensure this, we need to repack invocations that have a live vertex. */ nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size)); - wg_scan_result wg_scan = workgroup_reduce_and_exclusive_scan(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves); + wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves); - /* Reduction result = total number of vertices emitted in the workgroup. */ - nir_ssa_def *workgroup_num_vertices = wg_scan.reduction_result; - /* Exclusive scan result = the index of the thread in the workgroup that will export the current thread's vertex. */ - nir_ssa_def *exporter_tid_in_tg = wg_scan.excl_scan_result; + nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations; + nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index; /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */ nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0)); @@ -836,13 +860,13 @@ ac_nir_lower_ngg_gs(nir_shader *shader, lower_ngg_gs_state state = { .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size), .lds_addr_gs_out_vtx = esgs_ring_lds_bytes, - .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 16u), + .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */), .lds_offs_primflags = gs_out_vtx_bytes, .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u, .provoking_vertex_last = provoking_vertex_last, }; - unsigned lds_scratch_bytes = state.max_num_waves * 4u; + unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u; unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes; shader->info.shared_size = total_lds_bytes;