ac/nir/ngg: Trade 1 VALU shift for 2 SALU add.

Change the workgroup scan to be inclusive and adjust
the scalar operations after it.
This gets rid of 1 VALU instruction for 2 SALU. Win!

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Acked-by: Marek Olšák <marek.olsak@amd.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31973>
This commit is contained in:
Timur Kristóf 2024-11-05 02:25:44 +01:00
parent 340ec61984
commit 218c824e27

View file

@ -240,8 +240,8 @@ typedef struct {
/**
* Computes a horizontal sum of 8-bit packed values loaded from LDS.
*
* Each lane N will sum packed bytes 0 to N-1.
* We only care about the results from up to wave_id+1 lanes.
* Each lane N will sum packed bytes 0 to N.
* We only care about the results from up to wave_id lanes.
* (Other lanes are not deactivated but their calculation is not used.)
*/
static nir_def *
@ -249,22 +249,21 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
{
/* We'll use shift to filter out the bytes not needed by the current lane.
*
* Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
* However, two shifts are needed because one can't go all the way,
* so the shift amount is half that (and in bits).
* Need to shift by: `num_lds_dwords * 4 - 1 - lane_id` (in bytes)
* in order to implement an inclusive scan.
*
* When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
* This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
* therefore v_dot can get rid of the unneeded values.
* This sequence is preferable because it better hides the latency of the LDS.
*
* If the v_dot instruction can't be used, we left-shift the packed bytes.
* This will shift out the unneeded bytes and shift in zeroes instead,
* If the v_dot instruction can't be used, we left-shift the packed bytes
* in order to shift out the unneeded bytes and shift in zeroes instead,
* then we sum them using v_msad_u8.
*/
nir_def *lane_id = nir_load_subgroup_invocation(b);
nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8);
assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8);
bool use_dot = b->shader->options->has_udot_4x8;
if (num_lds_dwords == 1) {
@ -273,10 +272,10 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
/* Horizontally add the packed bytes. */
if (use_dot) {
nir_def *dot_op = nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
nir_def *dot_op = nir_ushr(b, nir_imm_int(b, 0x01010101), shift);
return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
} else {
nir_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
nir_def *sad_op = nir_ishl(b, packed, shift);
return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
}
} else if (num_lds_dwords == 2) {
@ -286,11 +285,11 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
/* Horizontally add the packed bytes. */
if (use_dot) {
nir_def *dot_op = nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
nir_def *dot_op = nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift);
nir_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
} else {
nir_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
nir_def *sad_op = nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift);
nir_def *sum = nir_msad_4x8(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
return nir_msad_4x8(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
}
@ -365,21 +364,24 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
* By now, every wave knows the number of surviving invocations in all waves.
* Each number is 1 byte, and they are packed into up to 2 dwords.
*
* Each lane N will sum the number of surviving invocations from waves 0 to N-1.
* If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
* Each lane N will sum the number of surviving invocations inclusively from waves 0 to N.
* If the workgroup has M waves, then each wave will use only its first M lanes for this.
* (Other lanes are not deactivated but their calculation is not used.)
*
* - We read the sum from the lane whose id is the current wave's id.
* - We read the sum from the lane whose id is the current wave's id,
* and subtract the number of its own surviving invocations.
* Add the masked bitcount to this, and we get the repacked invocation index.
* - We read the sum from the lane whose id is the number of waves in the workgroup.
* - We read the sum from the lane whose id is the number of waves in the workgroup minus 1.
* This is the total number of surviving invocations in the workgroup.
*/
nir_def *num_waves = nir_load_num_subgroups(b);
nir_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
nir_def *wg_repacked_index_base = nir_read_invocation(b, sum, wave_id);
nir_def *wg_num_repacked_invocations = nir_read_invocation(b, sum, num_waves);
nir_def *wg_repacked_index_base =
nir_isub(b, nir_read_invocation(b, sum, wave_id), surviving_invocations_in_current_wave);
nir_def *wg_num_repacked_invocations =
nir_read_invocation(b, sum, nir_iadd_imm(b, num_waves, -1));
nir_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
wg_repack_result r = {