ac/nir: add new helpers for computing the TCS LDS/offchip size accurately

This is based on how the HS lowering passes address TCS inputs and
outputs. The new LDS size is lower in some cases.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31673>
This commit is contained in:
Marek Olšák 2024-10-14 21:05:31 -04:00 committed by Marge Bot
parent 85c20def94
commit 3056bf1cb1
3 changed files with 54 additions and 2 deletions

View file

@ -104,6 +104,13 @@ void
ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
ac_nir_map_io_driver_location map);
/* Compute the TCS workgroup layout: how many tessellation patches are
 * processed per workgroup and the matching hardware LDS size.
 *
 * Inputs:
 *   info, wave_size, tess_uses_primid
 *       Forwarded to ac_compute_num_tess_patches(), which picks the patch count.
 *   tcs_info
 *       Supplies tess.tcs_vertices_out and the outputs_read/outputs_written and
 *       patch_outputs_read/patch_outputs_written masks used to size the LDS
 *       output area (16 bytes per counted output).
 *   all_invocations_define_tess_levels
 *       When true, tess level outputs are not counted in the per-patch LDS size.
 *   num_tcs_input_cp, lds_input_vertex_size
 *       Multiplied together to get the LDS bytes used by TCS inputs per patch.
 *   num_mem_tcs_outputs, num_mem_tcs_patch_outputs
 *       Per-vertex and per-patch output counts (16 bytes each) used to size the
 *       per-patch offchip memory.
 *
 * Outputs:
 *   num_patches_per_wg
 *       Receives the patch count returned by ac_compute_num_tess_patches().
 *   hw_lds_size
 *       Receives max(total LDS bytes, total offchip bytes) divided by
 *       info->lds_encode_granularity, rounded up. On GFX11 and newer the LDS
 *       total also includes AC_HS_MSG_VOTE_LDS_BYTES.
 */
void
ac_nir_compute_tess_wg_info(const struct radeon_info *info, const struct shader_info *tcs_info,
unsigned wave_size, bool tess_uses_primid, bool all_invocations_define_tess_levels,
unsigned num_tcs_input_cp, unsigned lds_input_vertex_size,
unsigned num_mem_tcs_outputs, unsigned num_mem_tcs_patch_outputs,
unsigned *num_patches_per_wg, unsigned *hw_lds_size);
void
ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
ac_nir_map_io_driver_location map,

View file

@ -4,9 +4,11 @@
* SPDX-License-Identifier: MIT
*/
#include "ac_gpu_info.h"
#include "ac_nir.h"
#include "ac_nir_helpers.h"
#include "nir_builder.h"
#include "util/u_math.h"
/*
* These NIR passes are used to lower NIR cross-stage I/O intrinsics into the
@ -1282,3 +1284,44 @@ ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
lower_tes_input_load,
&state);
}
/* Work out how many tessellation patches one workgroup processes and the
 * hardware LDS size that goes with that patch count.
 *
 * The per-patch LDS footprint is the sum of the TCS inputs
 * (num_tcs_input_cp * lds_input_vertex_size), the per-vertex outputs that the
 * TCS both writes and reads, and the per-patch outputs that it both writes and
 * reads. The per-patch offchip footprint covers the outputs described by
 * num_mem_tcs_outputs / num_mem_tcs_patch_outputs. Every output is 16 bytes.
 *
 * On return, *num_patches_per_wg holds the patch count chosen by
 * ac_compute_num_tess_patches() and *hw_lds_size holds the larger of the LDS
 * and offchip totals, divided by info->lds_encode_granularity (rounded up).
 */
void
ac_nir_compute_tess_wg_info(const struct radeon_info *info, const struct shader_info *tcs_info,
                            unsigned wave_size, bool tess_uses_primid, bool all_invocations_define_tess_levels,
                            unsigned num_tcs_input_cp, unsigned lds_input_vertex_size,
                            unsigned num_mem_tcs_outputs, unsigned num_mem_tcs_patch_outputs,
                            unsigned *num_patches_per_wg, unsigned *hw_lds_size)
{
   const unsigned output_cp_count = tcs_info->tess.tcs_vertices_out;

   /* Per-vertex outputs that are written and also read back. Tess levels are
    * handled with the per-patch outputs below.
    */
   const unsigned lds_vertex_output_slots =
      util_bitcount64(tcs_info->outputs_read & tcs_info->outputs_written & ~TESS_LVL_MASK);

   /* Written tess levels are counted in LDS unless all invocations define them. */
   uint64_t lds_tess_levels = 0;
   if (!all_invocations_define_tess_levels)
      lds_tess_levels = tcs_info->outputs_written & TESS_LVL_MASK;

   const unsigned lds_patch_output_slots =
      util_bitcount64(lds_tess_levels) +
      util_bitcount(tcs_info->patch_outputs_read & tcs_info->patch_outputs_written);

   /* Bytes needed per patch in LDS and in offchip memory (16 bytes per output). */
   const unsigned lds_bytes_per_patch = num_tcs_input_cp * lds_input_vertex_size +
                                        output_cp_count * (lds_vertex_output_slots * 16) +
                                        lds_patch_output_slots * 16;
   const unsigned mem_bytes_per_patch =
      (output_cp_count * num_mem_tcs_outputs + num_mem_tcs_patch_outputs) * 16;

   const unsigned patch_count =
      ac_compute_num_tess_patches(info, num_tcs_input_cp, output_cp_count, mem_bytes_per_patch,
                                  lds_bytes_per_patch, wave_size, tess_uses_primid);

   /* On GFX11+, the first vec4 is reserved for the tf0/1 shader message group vote. */
   const unsigned lds_bytes = lds_bytes_per_patch * patch_count +
                              (info->gfx_level >= GFX11 ? AC_HS_MSG_VOTE_LDS_BYTES : 0);
   const unsigned mem_bytes = mem_bytes_per_patch * patch_count;

   /* SPI_SHADER_PGM_RSRC2_HS.LDS_SIZE sets the allocation size for both LDS and
    * the HS offchip ring buffer, so the larger of the two requirements wins.
    *
    * LDS holds TCS inputs (only with cross-invocation or indirect access, or
    * when the TCS in/out vertex counts differ) and TCS outputs that are read,
    * including tess levels that must be re-read in invocation 0. The HS ring
    * buffer holds only the TCS outputs consumed by TES.
    */
   unsigned alloc_bytes = lds_bytes;
   if (mem_bytes > alloc_bytes)
      alloc_bytes = mem_bytes;

   assert(alloc_bytes <= (info->gfx_level >= GFX9 ? 65536 : 32768));

   *num_patches_per_wg = patch_count;
   *hw_lds_size = DIV_ROUND_UP(alloc_bytes, info->lds_encode_granularity);
}

View file

@ -1180,8 +1180,10 @@ uint32_t ac_compute_num_tess_patches(const struct radeon_info *info, uint32_t nu
* use LDS for the inputs and outputs.
*/
if (lds_per_patch) {
ASSERTED const unsigned max_lds_size = info->gfx_level >= GFX9 ? 64 * 1024 : 32 * 1024; /* hw limit */
const unsigned target_lds_size = max_lds_size / 2; /* target at least 2 workgroups per CU */
const unsigned max_lds_size = (info->gfx_level >= GFX9 ? 64 * 1024 : 32 * 1024); /* hw limit */
/* Target at least 2 workgroups per CU. */
const unsigned target_lds_size = max_lds_size / 2 -
(info->gfx_level >= GFX11 ? AC_HS_MSG_VOTE_LDS_BYTES : 0);
num_patches = MIN2(num_patches, target_lds_size / lds_per_patch);
assert(num_patches * lds_per_patch <= max_lds_size);
}