mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-08 06:58:05 +02:00
ac/nir/ngg: Trade 1 VALU shift for 2 SALU add.
Change the workgroup scan to be inclusive and adjust the scalar operations after it. This removes 1 VALU instruction at the cost of 2 additional SALU instructions, which is a net win.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Acked-by: Marek Olšák <marek.olsak@amd.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31973>
This commit is contained in:
parent
340ec61984
commit
218c824e27
1 changed file with 21 additions and 19 deletions
|
|
@ -240,8 +240,8 @@ typedef struct {
|
|||
/**
|
||||
* Computes a horizontal sum of 8-bit packed values loaded from LDS.
|
||||
*
|
||||
* Each lane N will sum packed bytes 0 to N-1.
|
||||
* We only care about the results from up to wave_id+1 lanes.
|
||||
* Each lane N will sum packed bytes 0 to N.
|
||||
* We only care about the results from up to wave_id lanes.
|
||||
* (Other lanes are not deactivated but their calculation is not used.)
|
||||
*/
|
||||
static nir_def *
|
||||
|
|
@ -249,22 +249,21 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
|
|||
{
|
||||
/* We'll use shift to filter out the bytes not needed by the current lane.
|
||||
*
|
||||
* Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
|
||||
* However, two shifts are needed because one can't go all the way,
|
||||
* so the shift amount is half that (and in bits).
|
||||
* Need to shift by: `num_lds_dwords * 4 - 1 - lane_id` (in bytes)
|
||||
* in order to implement an inclusive scan.
|
||||
*
|
||||
* When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
|
||||
* This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
|
||||
* therefore v_dot can get rid of the unneeded values.
|
||||
* This sequence is preferable because it better hides the latency of the LDS.
|
||||
*
|
||||
* If the v_dot instruction can't be used, we left-shift the packed bytes.
|
||||
* This will shift out the unneeded bytes and shift in zeroes instead,
|
||||
* If the v_dot instruction can't be used, we left-shift the packed bytes
|
||||
* in order to shift out the unneeded bytes and shift in zeroes instead,
|
||||
* then we sum them using v_msad_u8.
|
||||
*/
|
||||
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
|
||||
nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8);
|
||||
assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8);
|
||||
bool use_dot = b->shader->options->has_udot_4x8;
|
||||
|
||||
if (num_lds_dwords == 1) {
|
||||
|
|
@ -273,10 +272,10 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
|
|||
|
||||
/* Horizontally add the packed bytes. */
|
||||
if (use_dot) {
|
||||
nir_def *dot_op = nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
|
||||
nir_def *dot_op = nir_ushr(b, nir_imm_int(b, 0x01010101), shift);
|
||||
return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
|
||||
} else {
|
||||
nir_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
|
||||
nir_def *sad_op = nir_ishl(b, packed, shift);
|
||||
return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
}
|
||||
} else if (num_lds_dwords == 2) {
|
||||
|
|
@ -286,11 +285,11 @@ summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords
|
|||
|
||||
/* Horizontally add the packed bytes. */
|
||||
if (use_dot) {
|
||||
nir_def *dot_op = nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
|
||||
nir_def *dot_op = nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift);
|
||||
nir_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
|
||||
return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
|
||||
} else {
|
||||
nir_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
|
||||
nir_def *sad_op = nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift);
|
||||
nir_def *sum = nir_msad_4x8(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
return nir_msad_4x8(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
|
||||
}
|
||||
|
|
@ -365,21 +364,24 @@ repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
|
|||
* By now, every wave knows the number of surviving invocations in all waves.
|
||||
* Each number is 1 byte, and they are packed into up to 2 dwords.
|
||||
*
|
||||
* Each lane N will sum the number of surviving invocations from waves 0 to N-1.
|
||||
* If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
|
||||
* Each lane N will sum the number of surviving invocations inclusively from waves 0 to N.
|
||||
* If the workgroup has M waves, then each wave will use only its first M lanes for this.
|
||||
* (Other lanes are not deactivated but their calculation is not used.)
|
||||
*
|
||||
* - We read the sum from the lane whose id is the current wave's id.
|
||||
* - We read the sum from the lane whose id is the current wave's id,
|
||||
* and subtract the number of its own surviving invocations.
|
||||
* Add the masked bitcount to this, and we get the repacked invocation index.
|
||||
* - We read the sum from the lane whose id is the number of waves in the workgroup.
|
||||
* - We read the sum from the lane whose id is the number of waves in the workgroup minus 1.
|
||||
* This is the total number of surviving invocations in the workgroup.
|
||||
*/
|
||||
|
||||
nir_def *num_waves = nir_load_num_subgroups(b);
|
||||
nir_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
|
||||
|
||||
nir_def *wg_repacked_index_base = nir_read_invocation(b, sum, wave_id);
|
||||
nir_def *wg_num_repacked_invocations = nir_read_invocation(b, sum, num_waves);
|
||||
nir_def *wg_repacked_index_base =
|
||||
nir_isub(b, nir_read_invocation(b, sum, wave_id), surviving_invocations_in_current_wave);
|
||||
nir_def *wg_num_repacked_invocations =
|
||||
nir_read_invocation(b, sum, nir_iadd_imm(b, num_waves, -1));
|
||||
nir_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
|
||||
|
||||
wg_repack_result r = {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue