radeonsi: pass TCS inputs_read mask to LS output lowering on GFX9 + monolithic

With a monolithic merged LS+TCS, this allocates less LDS for LS outputs
when there are holes between varyings: instead of reserving LDS for the
holes, the outputs are packed. (A worked example is sketched below the
step list.)

There are 2 steps to this:
- add a si_shader_lshs_vertex_stride helper and use it everywhere
- pass the TCS inputs_read bitmask instead of the "map" callback
  to ac_nir_lower_ls_outputs_to_mem
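The saving, as a standalone sketch (not driver code: the util_* helpers
are local stand-ins for Mesa's u_math.h versions, and the example mask
is invented):

#include <stdint.h>
#include <stdio.h>

static unsigned util_last_bit64(uint64_t v)
{
   unsigned n = 0;
   while (v) {
      n++;
      v >>= 1;
   }
   return n;
}

static unsigned util_bitcount64(uint64_t v)
{
   unsigned n = 0;
   while (v) {
      n += v & 1;
      v >>= 1;
   }
   return n;
}

int main(void)
{
   /* The VS writes slots 0, 1 and 31; the TCS reads all of them,
    * leaving a hole between slots 2 and 30. */
   uint64_t mask = (1ull << 0) | (1ull << 1) | (1ull << 31);

   /* Before: every slot up to the last written one got LDS space,
    * plus 4 bytes of anti-bank-conflict padding. */
   unsigned old_stride = util_last_bit64(mask) * 16 + 4;   /* 32*16+4 = 516 */

   /* After (GFX9 + monolithic merged LS+TCS): only slots in the TCS
    * inputs_read mask are counted, so the hole costs nothing. */
   unsigned new_stride = util_bitcount64(mask) * 16 + 4;   /* 3*16+4 = 52 */

   printf("per-vertex LDS: old %u bytes, new %u bytes\n", old_stride, new_stride);
   return 0;
}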

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30962>
commit b7136d0890 (parent 52c41f25de)
Author: Marek Olšák, 2024-08-23 15:34:29 -04:00, committed by Marge Bot
5 changed files with 54 additions and 28 deletions


@@ -324,13 +324,14 @@ static bool lower_intrinsic(nir_builder *b, nir_instr *instr, struct lower_abi_s
       break;
    case nir_intrinsic_load_lshs_vertex_stride_amd:
       if (stage == MESA_SHADER_VERTEX) {
-         replacement = nir_imm_int(b, sel->info.lshs_vertex_stride);
+         replacement = nir_imm_int(b, si_shader_lshs_vertex_stride(shader));
       } else if (stage == MESA_SHADER_TESS_CTRL) {
          if (sel->screen->info.gfx_level >= GFX9 && shader->is_monolithic) {
-            replacement = nir_imm_int(b, key->ge.part.tcs.ls->info.lshs_vertex_stride);
+            replacement = nir_imm_int(b, si_shader_lshs_vertex_stride(shader));
          } else {
             nir_def *num_ls_out = ac_nir_unpack_arg(b, &args->ac, args->tcs_offchip_layout, 17, 6);
-            replacement = nir_iadd_imm_nuw(b, nir_ishl_imm(b, num_ls_out, 4), 4);
+            nir_def *extra_dw = nir_bcsel(b, nir_ieq_imm(b, num_ls_out, 0), nir_imm_int(b, 0), nir_imm_int(b, 4));
+            replacement = nir_iadd_nuw(b, nir_ishl_imm(b, num_ls_out, 4), extra_dw);
          }
       } else {
          unreachable("no nir_load_lshs_vertex_stride_amd");


@@ -1307,7 +1307,7 @@ void si_shader_dump_stats_for_shader_db(struct si_screen *screen, struct si_shad
     * for performance and can be optimized.
     */
    if (shader->key.ge.as_ls)
-      num_ls_outputs = shader->selector->info.lshs_vertex_stride / 16;
+      num_ls_outputs = si_shader_lshs_vertex_stride(shader) / 16;
    else if (shader->selector->stage == MESA_SHADER_TESS_CTRL)
       num_hs_outputs = util_last_bit64(shader->selector->info.outputs_written_before_tes_gs);
    else if (shader->key.ge.as_es)
@@ -1843,11 +1843,15 @@ static bool si_lower_io_to_mem(struct si_shader *shader, nir_shader *nir,
 {
    struct si_shader_selector *sel = shader->selector;
    const union si_shader_key *key = &shader->key;
+   const bool is_gfx9_mono_tcs = sel->stage == MESA_SHADER_TESS_CTRL && shader->is_monolithic &&
+                                 sel->screen->info.gfx_level >= GFX9;

    if (nir->info.stage == MESA_SHADER_VERTEX) {
       if (key->ge.as_ls) {
-         NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, si_map_io_driver_location,
-                    key->ge.opt.same_patch_vertices, ~0ULL, tcs_vgpr_only_inputs);
+         NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem,
+                    is_gfx9_mono_tcs ? NULL : si_map_io_driver_location,
+                    key->ge.opt.same_patch_vertices,
+                    is_gfx9_mono_tcs ? sel->info.base.inputs_read : ~0ull, tcs_vgpr_only_inputs);
          return true;
       } else if (key->ge.as_es) {
          NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, si_map_io_driver_location,
@@ -1855,7 +1859,8 @@ static bool si_lower_io_to_mem(struct si_shader *shader, nir_shader *nir,
          return true;
       }
    } else if (nir->info.stage == MESA_SHADER_TESS_CTRL) {
-      NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, si_map_io_driver_location,
+      NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem,
+                 is_gfx9_mono_tcs ? NULL : si_map_io_driver_location,
                  key->ge.opt.same_patch_vertices, sel->info.tcs_vgpr_only_inputs);

       /* Used by hs_emit_write_tess_factors() when monolithic shader. */
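In both calls above, a NULL map callback on GFX9 + monolithic leaves the
slot mapping to the shared ac_nir lowering, which can pack slots because
the real TCS inputs_read mask is passed alongside. As an illustration
only (the authoritative logic lives in the ac_nir lowering passes, and
this helper is invented for the example), the natural packing rule gives
each read slot the popcount of read slots below it:

#include <stdint.h>

/* Hypothetical helper: packed LDS location of `slot`, given the mask of
 * slots the TCS reads. Unread slots get no LDS space at all. */
unsigned packed_location(uint64_t inputs_read, unsigned slot)
{
   uint64_t below = inputs_read & ((1ull << slot) - 1);
   unsigned count = 0;

   while (below) {      /* portable popcount */
      count += below & 1;
      below >>= 1;
   }
   return count;
}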
@@ -3619,6 +3624,7 @@ nir_shader *si_get_prev_stage_nir_shader(struct si_shader *shader,
       prev_shader->key.ge.as_ngg = key->ge.as_ngg;
    }

+   prev_shader->next_shader = shader;
+   prev_shader->key.ge.mono = key->ge.mono;
    prev_shader->key.ge.opt = key->ge.opt;
    prev_shader->key.ge.opt.inline_uniforms = false; /* only TCS/GS can inline uniforms */


@@ -447,7 +447,6 @@ struct si_shader_info {
    uint8_t clipdist_mask;
    uint8_t culldist_mask;

-   uint16_t lshs_vertex_stride;
    uint16_t esgs_vertex_stride;
    uint16_t gsvs_vertex_size;
    uint8_t gs_input_verts_per_prim;
@@ -856,6 +855,7 @@ struct si_shader {
    struct si_shader_selector *selector;
    struct si_shader_selector *previous_stage_sel; /* for refcounting */
+   struct si_shader *next_shader; /* Only used during compilation of LS and ES when merged. */

    struct si_shader_part *prolog;
    struct si_shader *previous_stage; /* for GFX9 */
@@ -1034,7 +1034,7 @@ unsigned si_determine_wave_size(struct si_screen *sscreen, struct si_shader *sha
 void gfx9_get_gs_info(struct si_shader_selector *es, struct si_shader_selector *gs,
                       struct gfx9_gs_info *out);
 bool gfx10_is_ngg_passthrough(struct si_shader *shader);
+unsigned si_shader_lshs_vertex_stride(struct si_shader *ls);
 bool si_should_clear_lds(struct si_screen *sscreen, const struct nir_shader *shader);

 /* Inline helpers. */


@@ -837,14 +837,9 @@ void si_nir_scan_shader(struct si_screen *sscreen, const struct nir_shader *nir,
    if (nir->info.stage == MESA_SHADER_VERTEX ||
        nir->info.stage == MESA_SHADER_TESS_CTRL ||
        nir->info.stage == MESA_SHADER_TESS_EVAL) {
-      info->esgs_vertex_stride = info->lshs_vertex_stride =
+      info->esgs_vertex_stride =
          util_last_bit64(info->outputs_written_before_tes_gs) * 16;

-      /* Add 1 dword to reduce LDS bank conflicts, so that each vertex
-       * will start on a different bank. (except for the maximum 32*16).
-       */
-      info->lshs_vertex_stride += 4;
-
       /* For the ESGS ring in LDS, add 1 dword to reduce LDS bank
        * conflicts, i.e. each vertex will start on a different bank.
        */


@@ -4601,6 +4601,40 @@ static void si_set_patch_vertices(struct pipe_context *ctx, uint8_t patch_vertic
    }
 }

+unsigned si_shader_lshs_vertex_stride(struct si_shader *ls)
+{
+   unsigned num_slots;
+
+   if (ls->selector->stage == MESA_SHADER_VERTEX && !ls->next_shader) {
+      assert(ls->key.ge.as_ls);
+      assert(ls->selector->screen->info.gfx_level <= GFX8 || !ls->is_monolithic);
+      num_slots = util_last_bit64(ls->selector->info.outputs_written_before_tes_gs);
+   } else {
+      struct si_shader *tcs = ls->next_shader ? ls->next_shader : ls;
+
+      assert(tcs->selector->stage == MESA_SHADER_TESS_CTRL);
+      assert(tcs->selector->screen->info.gfx_level >= GFX9);
+
+      if (tcs->is_monolithic) {
+         uint64_t lds_inputs_read = tcs->selector->info.base.inputs_read;
+
+         /* Don't allocate LDS for inputs passed via VGPRs. */
+         if (tcs->key.ge.opt.same_patch_vertices)
+            lds_inputs_read &= ~tcs->selector->info.tcs_vgpr_only_inputs;
+
+         /* NIR lowering passes pack LS outputs/HS inputs if the usage masks of both are known. */
+         num_slots = util_bitcount64(lds_inputs_read);
+      } else {
+         num_slots = util_last_bit64(tcs->previous_stage_sel->info.outputs_written_before_tes_gs);
+      }
+   }
+
+   /* Add 1 dword to reduce LDS bank conflicts, so that each vertex starts on a different LDS
+    * bank.
+    */
+   return num_slots ? num_slots * 16 + 4 : 0;
+}
+
 /**
  * This calculates the LDS size for tessellation shaders (VS, TCS, TES).
  * LS.LDS_SIZE is shared by all 3 shader stages.
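Plugging numbers into the monolithic branch of si_shader_lshs_vertex_stride
above (all values invented for the example, and __builtin_popcountll stands
in for util_bitcount64):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
   uint64_t inputs_read          = 0xFull; /* TCS reads slots 0-3 */
   uint64_t tcs_vgpr_only_inputs = 0x3ull; /* slots 0-1 arrive in VGPRs */
   int      same_patch_vertices  = 1;

   /* Inputs passed via VGPRs don't need LDS slots. */
   uint64_t lds_inputs_read = inputs_read;
   if (same_patch_vertices)
      lds_inputs_read &= ~tcs_vgpr_only_inputs;   /* slots 2-3 remain */

   unsigned num_slots = (unsigned)__builtin_popcountll(lds_inputs_read);
   unsigned stride = num_slots ? num_slots * 16 + 4 : 0;   /* 2*16+4 = 36 */

   printf("LDS stride = %u bytes per vertex\n", stride);
   return 0;
}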
@@ -4618,7 +4652,6 @@ static void si_set_patch_vertices(struct pipe_context *ctx, uint8_t patch_vertic
 void si_update_tess_io_layout_state(struct si_context *sctx)
 {
    struct si_shader *ls_current;
-   struct si_shader_selector *ls;
    struct si_shader_selector *tcs = sctx->shader.tcs.cso;
    unsigned tess_uses_primid = sctx->ia_multi_vgt_param_key.u.tess_uses_prim_id;
    bool has_primid_instancing_bug = sctx->gfx_level == GFX6 && sctx->screen->info.max_se == 1;
@@ -4630,10 +4663,8 @@ void si_update_tess_io_layout_state(struct si_context *sctx)
    /* Since GFX9 has merged LS-HS in the TCS state, set LS = TCS. */
    if (sctx->gfx_level >= GFX9) {
       ls_current = sctx->shader.tcs.current;
-      ls = ls_current->key.ge.part.tcs.ls;
    } else {
       ls_current = sctx->shader.vs.current;
-      ls = sctx->shader.vs.cso;
    }
    if (!ls_current) {
       sctx->do_update_shaders = true;
@@ -4658,17 +4689,10 @@ void si_update_tess_io_layout_state(struct si_context *sctx)
    unsigned num_tcs_output_cp = tcs->info.base.tess.tcs_vertices_out;
    unsigned num_tcs_patch_outputs = util_last_bit64(tcs->info.patch_outputs_written);
-   unsigned input_vertex_size = ls->info.lshs_vertex_stride;
-   unsigned num_vs_outputs = (input_vertex_size - 4) / 16;
+   unsigned input_vertex_size = si_shader_lshs_vertex_stride(ls_current);
+   unsigned num_vs_outputs = input_vertex_size / 16;
    unsigned output_vertex_size = num_tcs_outputs * 16;
-   unsigned input_patch_size;
-
-   /* Allocate LDS for TCS inputs only if it's used. */
-   if (!ls_current->key.ge.opt.same_patch_vertices ||
-       tcs->info.base.inputs_read & ~tcs->info.tcs_vgpr_only_inputs)
-      input_patch_size = num_tcs_input_cp * input_vertex_size;
-   else
-      input_patch_size = 0;
+   unsigned input_patch_size = num_tcs_input_cp * input_vertex_size;
    unsigned pervertex_output_patch_size = num_tcs_output_cp * output_vertex_size;
    unsigned output_patch_size = pervertex_output_patch_size + num_tcs_patch_outputs * 16;
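The special case deleted above is now implicit: when same_patch_vertices
is set and every TCS input arrives in VGPRs, si_shader_lshs_vertex_stride
returns 0, so input_patch_size multiplies out to 0 on its own. A trivial
standalone check with invented values:

#include <stdio.h>

int main(void)
{
   unsigned num_tcs_input_cp  = 3;   /* input control points per patch */
   unsigned input_vertex_size = 0;   /* si_shader_lshs_vertex_stride() == 0 */
   unsigned input_patch_size  = num_tcs_input_cp * input_vertex_size;

   printf("input_patch_size = %u\n", input_patch_size);   /* 0, as before */
   return 0;
}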