diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index 986df8f62d0..1a0d7b33e2e 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -240,8 +240,8 @@ typedef struct { /** * Computes a horizontal sum of 8-bit packed values loaded from LDS. * - * Each lane N will sum packed bytes 0 to N-1. - * We only care about the results from up to wave_id+1 lanes. + * Each lane N will sum packed bytes 0 to N. + * We only care about the results from up to wave_id lanes. * (Other lanes are not deactivated but their calculation is not used.) */ static nir_def * @@ -249,22 +249,21 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords { /* We'll use shift to filter out the bytes not needed by the current lane. * - * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes). - * However, two shifts are needed because one can't go all the way, - * so the shift amount is half that (and in bits). + * Need to shift by: `num_lds_dwords * 4 - 1 - lane_id` (in bytes) + * in order to implement an inclusive scan. * * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes. * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions, * therefore v_dot can get rid of the unneeded values. - * This sequence is preferable because it better hides the latency of the LDS. * - * If the v_dot instruction can't be used, we left-shift the packed bytes. - * This will shift out the unneeded bytes and shift in zeroes instead, + * If the v_dot instruction can't be used, we left-shift the packed bytes + * in order to shift out the unneeded bytes and shift in zeroes instead, * then we sum them using v_msad_u8. */ nir_def *lane_id = nir_load_subgroup_invocation(b); - nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16); + nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8); + assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8); bool use_dot = b->shader->options->has_udot_4x8; if (num_lds_dwords == 1) { @@ -273,10 +272,10 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords /* Horizontally add the packed bytes. */ if (use_dot) { - nir_def *dot_op = nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift); + nir_def *dot_op = nir_ushr(b, nir_imm_int(b, 0x01010101), shift); return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0)); } else { - nir_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift); + nir_def *sad_op = nir_ishl(b, packed, shift); return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0)); } } else if (num_lds_dwords == 2) { @@ -286,11 +285,11 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords /* Horizontally add the packed bytes. */ if (use_dot) { - nir_def *dot_op = nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift); + nir_def *dot_op = nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift); nir_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0)); return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum); } else { - nir_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift); + nir_def *sad_op = nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift); nir_def *sum = nir_msad_4x8(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0)); return nir_msad_4x8(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum); } @@ -365,21 +364,24 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool, * By now, every wave knows the number of surviving invocations in all waves. * Each number is 1 byte, and they are packed into up to 2 dwords. * - * Each lane N will sum the number of surviving invocations from waves 0 to N-1. - * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this. + * Each lane N will sum the number of surviving invocations inclusively from waves 0 to N. + * If the workgroup has M waves, then each wave will use only its first M lanes for this. * (Other lanes are not deactivated but their calculation is not used.) * - * - We read the sum from the lane whose id is the current wave's id. + * - We read the sum from the lane whose id is the current wave's id, + * and subtract the number of its own surviving invocations. * Add the masked bitcount to this, and we get the repacked invocation index. - * - We read the sum from the lane whose id is the number of waves in the workgroup. + * - We read the sum from the lane whose id is the number of waves in the workgroup minus 1. * This is the total number of surviving invocations in the workgroup. */ nir_def *num_waves = nir_load_num_subgroups(b); nir_def *sum = summarize_repack(b, packed_counts, num_lds_dwords); - nir_def *wg_repacked_index_base = nir_read_invocation(b, sum, wave_id); - nir_def *wg_num_repacked_invocations = nir_read_invocation(b, sum, num_waves); + nir_def *wg_repacked_index_base = + nir_isub(b, nir_read_invocation(b, sum, wave_id), surviving_invocations_in_current_wave); + nir_def *wg_num_repacked_invocations = + nir_read_invocation(b, sum, nir_iadd_imm(b, num_waves, -1)); nir_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base); wg_repack_result r = {