From b7136d08901499d6dd1b55d2b039683259900a31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Ol=C5=A1=C3=A1k?= Date: Fri, 23 Aug 2024 15:34:29 -0400 Subject: [PATCH] radeonsi: pass TCS inputs_read mask to LS output lowering on GFX9 + monolithic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allocates less LDS for LS outputs when there are holes between varyings in a monolithic merged LS+TCS shader, by removing the holes. There are 2 steps to this: - add the helper si_shader_lshs_vertex_stride and use it everywhere - pass the TCS inputs_read bitmask instead of the "map" callback to ac_nir_lower_ls_outputs_to_mem Reviewed-by: Timur Kristóf Part-of: --- .../drivers/radeonsi/si_nir_lower_abi.c | 7 +-- src/gallium/drivers/radeonsi/si_shader.c | 14 ++++-- src/gallium/drivers/radeonsi/si_shader.h | 4 +- src/gallium/drivers/radeonsi/si_shader_info.c | 7 +-- .../drivers/radeonsi/si_state_shaders.cpp | 50 ++++++++++++++----- 5 files changed, 54 insertions(+), 28 deletions(-) diff --git a/src/gallium/drivers/radeonsi/si_nir_lower_abi.c b/src/gallium/drivers/radeonsi/si_nir_lower_abi.c index 01bf5351d52..a7a6ae71671 100644 --- a/src/gallium/drivers/radeonsi/si_nir_lower_abi.c +++ b/src/gallium/drivers/radeonsi/si_nir_lower_abi.c @@ -324,13 +324,14 @@ static bool lower_intrinsic(nir_builder *b, nir_instr *instr, struct lower_abi_s break; case nir_intrinsic_load_lshs_vertex_stride_amd: if (stage == MESA_SHADER_VERTEX) { - replacement = nir_imm_int(b, sel->info.lshs_vertex_stride); + replacement = nir_imm_int(b, si_shader_lshs_vertex_stride(shader)); } else if (stage == MESA_SHADER_TESS_CTRL) { if (sel->screen->info.gfx_level >= GFX9 && shader->is_monolithic) { - replacement = nir_imm_int(b, key->ge.part.tcs.ls->info.lshs_vertex_stride); + replacement = nir_imm_int(b, si_shader_lshs_vertex_stride(shader)); } else { nir_def *num_ls_out = ac_nir_unpack_arg(b, &args->ac, args->tcs_offchip_layout, 17, 6); - replacement = nir_iadd_imm_nuw(b, 
nir_ishl_imm(b, num_ls_out, 4), 4); + nir_def *extra_dw = nir_bcsel(b, nir_ieq_imm(b, num_ls_out, 0), nir_imm_int(b, 0), nir_imm_int(b, 4)); + replacement = nir_iadd_nuw(b, nir_ishl_imm(b, num_ls_out, 4), extra_dw); } } else { unreachable("no nir_load_lshs_vertex_stride_amd"); diff --git a/src/gallium/drivers/radeonsi/si_shader.c b/src/gallium/drivers/radeonsi/si_shader.c index 97f22f3ff2f..830a96ac82d 100644 --- a/src/gallium/drivers/radeonsi/si_shader.c +++ b/src/gallium/drivers/radeonsi/si_shader.c @@ -1307,7 +1307,7 @@ void si_shader_dump_stats_for_shader_db(struct si_screen *screen, struct si_shad * for performance and can be optimized. */ if (shader->key.ge.as_ls) - num_ls_outputs = shader->selector->info.lshs_vertex_stride / 16; + num_ls_outputs = si_shader_lshs_vertex_stride(shader) / 16; else if (shader->selector->stage == MESA_SHADER_TESS_CTRL) num_hs_outputs = util_last_bit64(shader->selector->info.outputs_written_before_tes_gs); else if (shader->key.ge.as_es) @@ -1843,11 +1843,15 @@ static bool si_lower_io_to_mem(struct si_shader *shader, nir_shader *nir, { struct si_shader_selector *sel = shader->selector; const union si_shader_key *key = &shader->key; + const bool is_gfx9_mono_tcs = sel->stage == MESA_SHADER_TESS_CTRL && shader->is_monolithic && + sel->screen->info.gfx_level >= GFX9; if (nir->info.stage == MESA_SHADER_VERTEX) { if (key->ge.as_ls) { - NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, si_map_io_driver_location, - key->ge.opt.same_patch_vertices, ~0ULL, tcs_vgpr_only_inputs); + NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, + is_gfx9_mono_tcs ? NULL : si_map_io_driver_location, + key->ge.opt.same_patch_vertices, + is_gfx9_mono_tcs ? 
sel->info.base.inputs_read : ~0ull, tcs_vgpr_only_inputs); return true; } else if (key->ge.as_es) { NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, si_map_io_driver_location, @@ -1855,7 +1859,8 @@ static bool si_lower_io_to_mem(struct si_shader *shader, nir_shader *nir, return true; } } else if (nir->info.stage == MESA_SHADER_TESS_CTRL) { - NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, si_map_io_driver_location, + NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, + is_gfx9_mono_tcs ? NULL : si_map_io_driver_location, key->ge.opt.same_patch_vertices, sel->info.tcs_vgpr_only_inputs); /* Used by hs_emit_write_tess_factors() when monolithic shader. */ @@ -3619,6 +3624,7 @@ nir_shader *si_get_prev_stage_nir_shader(struct si_shader *shader, prev_shader->key.ge.as_ngg = key->ge.as_ngg; } + prev_shader->next_shader = shader; prev_shader->key.ge.mono = key->ge.mono; prev_shader->key.ge.opt = key->ge.opt; prev_shader->key.ge.opt.inline_uniforms = false; /* only TCS/GS can inline uniforms */ diff --git a/src/gallium/drivers/radeonsi/si_shader.h b/src/gallium/drivers/radeonsi/si_shader.h index a0ba146f150..1b0b99e5e23 100644 --- a/src/gallium/drivers/radeonsi/si_shader.h +++ b/src/gallium/drivers/radeonsi/si_shader.h @@ -447,7 +447,6 @@ struct si_shader_info { uint8_t clipdist_mask; uint8_t culldist_mask; - uint16_t lshs_vertex_stride; uint16_t esgs_vertex_stride; uint16_t gsvs_vertex_size; uint8_t gs_input_verts_per_prim; @@ -856,6 +855,7 @@ struct si_shader { struct si_shader_selector *selector; struct si_shader_selector *previous_stage_sel; /* for refcounting */ + struct si_shader *next_shader; /* Only used during compilation of LS and ES when merged. 
*/ struct si_shader_part *prolog; struct si_shader *previous_stage; /* for GFX9 */ @@ -1034,7 +1034,7 @@ unsigned si_determine_wave_size(struct si_screen *sscreen, struct si_shader *sha void gfx9_get_gs_info(struct si_shader_selector *es, struct si_shader_selector *gs, struct gfx9_gs_info *out); bool gfx10_is_ngg_passthrough(struct si_shader *shader); - +unsigned si_shader_lshs_vertex_stride(struct si_shader *ls); bool si_should_clear_lds(struct si_screen *sscreen, const struct nir_shader *shader); /* Inline helpers. */ diff --git a/src/gallium/drivers/radeonsi/si_shader_info.c b/src/gallium/drivers/radeonsi/si_shader_info.c index 2953e0a9b98..b0772a88010 100644 --- a/src/gallium/drivers/radeonsi/si_shader_info.c +++ b/src/gallium/drivers/radeonsi/si_shader_info.c @@ -837,14 +837,9 @@ void si_nir_scan_shader(struct si_screen *sscreen, const struct nir_shader *nir, if (nir->info.stage == MESA_SHADER_VERTEX || nir->info.stage == MESA_SHADER_TESS_CTRL || nir->info.stage == MESA_SHADER_TESS_EVAL) { - info->esgs_vertex_stride = info->lshs_vertex_stride = + info->esgs_vertex_stride = util_last_bit64(info->outputs_written_before_tes_gs) * 16; - /* Add 1 dword to reduce LDS bank conflicts, so that each vertex - * will start on a different bank. (except for the maximum 32*16). - */ - info->lshs_vertex_stride += 4; - /* For the ESGS ring in LDS, add 1 dword to reduce LDS bank * conflicts, i.e. each vertex will start on a different bank. 
*/ diff --git a/src/gallium/drivers/radeonsi/si_state_shaders.cpp b/src/gallium/drivers/radeonsi/si_state_shaders.cpp index 9e01da90f59..dbbfec22018 100644 --- a/src/gallium/drivers/radeonsi/si_state_shaders.cpp +++ b/src/gallium/drivers/radeonsi/si_state_shaders.cpp @@ -4601,6 +4601,40 @@ static void si_set_patch_vertices(struct pipe_context *ctx, uint8_t patch_vertic } } +unsigned si_shader_lshs_vertex_stride(struct si_shader *ls) +{ + unsigned num_slots; + + if (ls->selector->stage == MESA_SHADER_VERTEX && !ls->next_shader) { + assert(ls->key.ge.as_ls); + assert(ls->selector->screen->info.gfx_level <= GFX8 || !ls->is_monolithic); + num_slots = util_last_bit64(ls->selector->info.outputs_written_before_tes_gs); + } else { + struct si_shader *tcs = ls->next_shader ? ls->next_shader : ls; + + assert(tcs->selector->stage == MESA_SHADER_TESS_CTRL); + assert(tcs->selector->screen->info.gfx_level >= GFX9); + + if (tcs->is_monolithic) { + uint64_t lds_inputs_read = tcs->selector->info.base.inputs_read; + + /* Don't allocate LDS for inputs passed via VGPRs. */ + if (tcs->key.ge.opt.same_patch_vertices) + lds_inputs_read &= ~tcs->selector->info.tcs_vgpr_only_inputs; + + /* NIR lowering passes pack LS outputs/HS inputs if the usage masks of both are known. */ + num_slots = util_bitcount64(lds_inputs_read); + } else { + num_slots = util_last_bit64(tcs->previous_stage_sel->info.outputs_written_before_tes_gs); + } + } + + /* Add 1 dword to reduce LDS bank conflicts, so that each vertex starts on a different LDS + * bank. + */ + return num_slots ? num_slots * 16 + 4 : 0; +} + /** * This calculates the LDS size for tessellation shaders (VS, TCS, TES). * LS.LDS_SIZE is shared by all 3 shader stages. 
@@ -4618,7 +4652,6 @@ static void si_set_patch_vertices(struct pipe_context *ctx, uint8_t patch_vertic void si_update_tess_io_layout_state(struct si_context *sctx) { struct si_shader *ls_current; - struct si_shader_selector *ls; struct si_shader_selector *tcs = sctx->shader.tcs.cso; unsigned tess_uses_primid = sctx->ia_multi_vgt_param_key.u.tess_uses_prim_id; bool has_primid_instancing_bug = sctx->gfx_level == GFX6 && sctx->screen->info.max_se == 1; @@ -4630,10 +4663,8 @@ void si_update_tess_io_layout_state(struct si_context *sctx) /* Since GFX9 has merged LS-HS in the TCS state, set LS = TCS. */ if (sctx->gfx_level >= GFX9) { ls_current = sctx->shader.tcs.current; - ls = ls_current->key.ge.part.tcs.ls; } else { ls_current = sctx->shader.vs.current; - ls = sctx->shader.vs.cso; if (!ls_current) { sctx->do_update_shaders = true; @@ -4658,17 +4689,10 @@ void si_update_tess_io_layout_state(struct si_context *sctx) unsigned num_tcs_output_cp = tcs->info.base.tess.tcs_vertices_out; unsigned num_tcs_patch_outputs = util_last_bit64(tcs->info.patch_outputs_written); - unsigned input_vertex_size = ls->info.lshs_vertex_stride; - unsigned num_vs_outputs = (input_vertex_size - 4) / 16; + unsigned input_vertex_size = si_shader_lshs_vertex_stride(ls_current); + unsigned num_vs_outputs = input_vertex_size / 16; unsigned output_vertex_size = num_tcs_outputs * 16; - unsigned input_patch_size; - - /* Allocate LDS for TCS inputs only if it's used. */ - if (!ls_current->key.ge.opt.same_patch_vertices || - tcs->info.base.inputs_read & ~tcs->info.tcs_vgpr_only_inputs) - input_patch_size = num_tcs_input_cp * input_vertex_size; - else - input_patch_size = 0; + unsigned input_patch_size = num_tcs_input_cp * input_vertex_size; unsigned pervertex_output_patch_size = num_tcs_output_cp * output_vertex_size; unsigned output_patch_size = pervertex_output_patch_size + num_tcs_patch_outputs * 16;