radeonsi: init pm4 state for mesh shader

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37505>
This commit is contained in:
Qiang Yu 2025-04-23 16:13:33 +08:00 committed by Marge Bot
parent ce6a1e7563
commit 74894150f1
6 changed files with 117 additions and 13 deletions

View file

@ -127,6 +127,11 @@ enum si_has_ngg {
NGG_ON,
};
enum si_has_ms {
MS_OFF,
MS_ON,
};
#define DCC_CODE(x) (((x) << 24) | ((x) << 16) | ((x) << 8) | (x))
enum si_clear_code

View file

@ -848,6 +848,9 @@ struct si_shader {
unsigned spi_shader_pgm_rsrc3_gs;
unsigned spi_shader_pgm_rsrc4_gs;
unsigned vgt_shader_stages_en;
unsigned spi_shader_gs_meshlet_dim;
unsigned spi_shader_gs_meshlet_exp_alloc;
unsigned spi_shader_gs_meshlet_ctrl;
} ngg;
struct {
@ -916,6 +919,7 @@ struct nir_shader *si_deserialize_shader(struct si_shader_selector *sel);
unsigned si_get_ps_num_interp(struct si_shader *ps);
unsigned si_get_shader_prefetch_size(struct si_shader *shader);
unsigned si_get_shader_binary_size(struct si_screen *screen, struct si_shader *shader);
unsigned si_get_max_workgroup_size(const struct si_shader *shader);
/* si_shader_info.c */
void si_nir_scan_shader(struct si_screen *sscreen, struct nir_shader *nir,

View file

@ -221,6 +221,7 @@ struct si_shader_variant_info {
bool writes_stencil : 1;
bool writes_sample_mask : 1;
bool uses_discard : 1;
bool uses_mesh_scratch_ring : 1;
uint8_t nr_pos_exports;
uint8_t nr_param_exports;
uint8_t clipdist_mask;

View file

@ -95,7 +95,6 @@ typedef struct nir_shader nir_shader;
/* si_shader.c */
bool si_is_multi_part_shader(struct si_shader *shader);
bool si_is_merged_shader(struct si_shader *shader);
unsigned si_get_max_workgroup_size(const struct si_shader *shader);
enum ac_hw_stage si_select_hw_stage(const mesa_shader_stage stage, const union si_shader_key *const key,
const enum amd_gfx_level gfx_level);
bool gfx10_ngg_export_prim_early(struct si_shader *shader);

View file

@ -429,6 +429,11 @@ enum si_tracked_reg
SI_TRACKED_COMPUTE_DISPATCH_SCRATCH_BASE_LO, /* GFX11+ */
SI_TRACKED_COMPUTE_DISPATCH_SCRATCH_BASE_HI, /* GFX11+ */
/* 3 consecutive registers. */
SI_TRACKED_SPI_SHADER_GS_MESHLET_DIM, /* GFX11+ */
SI_TRACKED_SPI_SHADER_GS_MESHLET_EXP_ALLOC, /* GFX11+ */
SI_TRACKED_SPI_SHADER_GS_MESHLET_CTRL, /* GFX12+ */
SI_NUM_ALL_TRACKED_REGS,
};

View file

@ -1173,7 +1173,7 @@ static void gfx10_emit_shader_ngg(struct si_context *sctx, unsigned index)
radeon_end();
}
template <enum si_has_tess HAS_TESS>
template <enum si_has_tess HAS_TESS, enum si_has_ms HAS_MS>
static void gfx11_dgpu_emit_shader_ngg(struct si_context *sctx, unsigned index)
{
struct si_shader *shader = sctx->queued.named.gs;
@ -1214,6 +1214,14 @@ static void gfx11_dgpu_emit_shader_ngg(struct si_context *sctx, unsigned index)
gfx11_opt_push_gfx_sh_reg(R_00B204_SPI_SHADER_PGM_RSRC4_GS,
SI_TRACKED_SPI_SHADER_PGM_RSRC4_GS,
shader->ngg.spi_shader_pgm_rsrc4_gs);
if (HAS_MS) {
gfx11_opt_push_gfx_sh_reg(R_00B2B0_SPI_SHADER_GS_MESHLET_DIM,
SI_TRACKED_SPI_SHADER_GS_MESHLET_DIM,
shader->ngg.spi_shader_gs_meshlet_dim);
gfx11_opt_push_gfx_sh_reg(R_00B2B4_SPI_SHADER_GS_MESHLET_EXP_ALLOC,
SI_TRACKED_SPI_SHADER_GS_MESHLET_EXP_ALLOC,
shader->ngg.spi_shader_gs_meshlet_exp_alloc);
}
} else {
if (sctx->screen->info.uses_kernel_cu_mask) {
radeon_opt_set_sh_reg_idx(R_00B21C_SPI_SHADER_PGM_RSRC3_GS,
@ -1230,6 +1238,12 @@ static void gfx11_dgpu_emit_shader_ngg(struct si_context *sctx, unsigned index)
SI_TRACKED_SPI_SHADER_PGM_RSRC4_GS,
shader->ngg.spi_shader_pgm_rsrc4_gs);
}
if (HAS_MS) {
radeon_opt_set_sh_reg2(R_00B2B0_SPI_SHADER_GS_MESHLET_DIM,
SI_TRACKED_SPI_SHADER_GS_MESHLET_DIM,
shader->ngg.spi_shader_gs_meshlet_dim,
shader->ngg.spi_shader_gs_meshlet_exp_alloc);
}
}
radeon_opt_set_uconfig_reg(R_030980_GE_PC_ALLOC, SI_TRACKED_GE_PC_ALLOC,
@ -1237,7 +1251,7 @@ static void gfx11_dgpu_emit_shader_ngg(struct si_context *sctx, unsigned index)
radeon_end();
}
template <enum si_has_tess HAS_TESS>
template <enum si_has_tess HAS_TESS, enum si_has_ms HAS_MS>
static void gfx12_emit_shader_ngg(struct si_context *sctx, unsigned index)
{
struct si_shader *shader = sctx->queued.named.gs;
@ -1275,6 +1289,17 @@ static void gfx12_emit_shader_ngg(struct si_context *sctx, unsigned index)
gfx12_opt_push_gfx_sh_reg(R_00B220_SPI_SHADER_PGM_RSRC4_GS,
SI_TRACKED_SPI_SHADER_PGM_RSRC4_GS,
shader->ngg.spi_shader_pgm_rsrc4_gs);
if (HAS_MS) {
gfx12_opt_push_gfx_sh_reg(R_00B2B0_SPI_SHADER_GS_MESHLET_DIM,
SI_TRACKED_SPI_SHADER_GS_MESHLET_DIM,
shader->ngg.spi_shader_gs_meshlet_dim);
gfx12_opt_push_gfx_sh_reg(R_00B2B4_SPI_SHADER_GS_MESHLET_EXP_ALLOC,
SI_TRACKED_SPI_SHADER_GS_MESHLET_EXP_ALLOC,
shader->ngg.spi_shader_gs_meshlet_exp_alloc);
gfx12_opt_push_gfx_sh_reg(R_00B2B8_SPI_SHADER_GS_MESHLET_CTRL,
SI_TRACKED_SPI_SHADER_GS_MESHLET_CTRL,
shader->ngg.spi_shader_gs_meshlet_ctrl);
}
}
unsigned si_get_input_prim(const struct si_shader_selector *gs, const union si_shader_key *key,
@ -1291,6 +1316,10 @@ unsigned si_get_input_prim(const struct si_shader_selector *gs, const union si_s
return MESA_PRIM_TRIANGLES;
}
/* Just fake to be points input for NGG calculation. */
if (gs->stage == MESA_SHADER_MESH)
return MESA_PRIM_POINTS;
assert(gs->stage == MESA_SHADER_VERTEX);
if (key->ge.opt.ngg_culling & SI_NGG_CULL_VS_LINES)
@ -1396,7 +1425,7 @@ unsigned si_shader_num_alloc_param_exports(struct si_shader *shader)
}
/**
* Prepare the PM4 image for \p shader, which will run as a merged ESGS shader
* Prepare the PM4 image for \p shader, which will run as a merged ESGS or MS shader
* in NGG mode.
*/
static void gfx10_shader_ngg(struct si_screen *sscreen, struct si_shader *shader)
@ -1426,14 +1455,18 @@ static void gfx10_shader_ngg(struct si_screen *sscreen, struct si_shader *shader
if (sscreen->info.gfx_level >= GFX12) {
if (es_stage == MESA_SHADER_TESS_EVAL)
pm4->atom.emit = gfx12_emit_shader_ngg<TESS_ON>;
pm4->atom.emit = gfx12_emit_shader_ngg<TESS_ON, MS_OFF>;
else if (gs_stage == MESA_SHADER_MESH)
pm4->atom.emit = gfx12_emit_shader_ngg<TESS_OFF, MS_ON>;
else
pm4->atom.emit = gfx12_emit_shader_ngg<TESS_OFF>;
pm4->atom.emit = gfx12_emit_shader_ngg<TESS_OFF, MS_OFF>;
} else if (sscreen->info.has_set_context_pairs_packed) {
if (es_stage == MESA_SHADER_TESS_EVAL)
pm4->atom.emit = gfx11_dgpu_emit_shader_ngg<TESS_ON>;
pm4->atom.emit = gfx11_dgpu_emit_shader_ngg<TESS_ON, MS_OFF>;
else if (gs_stage == MESA_SHADER_MESH)
pm4->atom.emit = gfx11_dgpu_emit_shader_ngg<TESS_OFF, MS_ON>;
else
pm4->atom.emit = gfx11_dgpu_emit_shader_ngg<TESS_OFF>;
pm4->atom.emit = gfx11_dgpu_emit_shader_ngg<TESS_OFF, MS_OFF>;
} else {
if (es_stage == MESA_SHADER_TESS_EVAL)
pm4->atom.emit = gfx10_emit_shader_ngg<TESS_ON>;
@ -1452,6 +1485,21 @@ static void gfx10_shader_ngg(struct si_screen *sscreen, struct si_shader *shader
} else {
num_user_sgprs = si_get_num_vs_user_sgprs(shader, GFX9_GS_NUM_USER_SGPR);
}
} else if (es_stage == MESA_SHADER_MESH) {
es_vgpr_comp_cnt = 0;
num_user_sgprs = GFX11_SGPR_MS_ATTRIBUTE_RING_ADDR;
if (sscreen->info.gfx_level >= GFX11)
num_user_sgprs++;
/* task ring entry */
num_user_sgprs++;
if (gs_sel->info.base.task_payload_size)
num_user_sgprs++;
if (shader->info.uses_draw_id)
num_user_sgprs++;
if (gs_sel->info.uses_grid_size || sscreen->info.gfx_level < GFX11)
num_user_sgprs += 3;
if (shader->info.uses_mesh_scratch_ring)
num_user_sgprs++;
} else {
assert(es_stage == MESA_SHADER_TESS_EVAL);
es_vgpr_comp_cnt = es_enable_prim_id ? 3 : 2;
@ -1532,6 +1580,10 @@ static void gfx10_shader_ngg(struct si_screen *sscreen, struct si_shader *shader
shader->ngg.esgs_vertex_stride = es_sel->info.esgs_vertex_stride / 4;
shader->ngg.vgt_gs_max_vert_out = gs_sel->info.base.gs.vertices_out;
shader->ngg.ge_ngg_subgrp_cntl = S_028B4C_PRIM_AMP_FACTOR(gs_sel->info.base.gs.vertices_out);
} else if (gs_stage == MESA_SHADER_MESH) {
shader->ngg.vgt_gs_max_vert_out = sscreen->info.mesh_fast_launch_2 ?
gs_info->base.mesh.max_vertices_out : si_get_max_workgroup_size(shader);
shader->ngg.ge_ngg_subgrp_cntl = gs_info->base.mesh.max_primitives_out;
} else {
shader->ngg.esgs_vertex_stride = 1;
shader->ngg.vgt_gs_max_vert_out = 1;
@ -1671,26 +1723,61 @@ static void gfx10_shader_ngg(struct si_screen *sscreen, struct si_shader *shader
S_028818_VPORT_Z_SCALE_ENA(1) | S_028818_VPORT_Z_OFFSET_ENA(1);
}
bool ngg_wave_id_en =
shader->info.num_streamout_vec4s != 0 || shader->info.uses_mesh_scratch_ring;
if (sscreen->info.gfx_level >= GFX12) {
shader->ngg.vgt_shader_stages_en =
S_028A98_GS_EN(gs_stage == MESA_SHADER_GEOMETRY) |
S_028A98_GS_FAST_LAUNCH(gs_stage == MESA_SHADER_MESH) |
S_028A98_PRIMGEN_PASSTHRU_NO_MSG(gfx10_is_ngg_passthrough(shader)) |
S_028A98_GS_W32_EN(shader->wave_size == 32) |
S_028A98_NGG_WAVE_ID_EN(shader->info.num_streamout_vec4s != 0);
S_028A98_NGG_WAVE_ID_EN(ngg_wave_id_en);
} else {
shader->ngg.vgt_shader_stages_en =
S_028B54_ES_EN(es_stage == MESA_SHADER_TESS_EVAL ?
if (gs_stage == MESA_SHADER_MESH) {
shader->ngg.vgt_shader_stages_en =
S_028B54_GS_EN(1) |
S_028B54_GS_FAST_LAUNCH(sscreen->info.mesh_fast_launch_2 ? 2 : 1);
} else {
shader->ngg.vgt_shader_stages_en =
S_028B54_ES_EN(es_stage == MESA_SHADER_TESS_EVAL ?
V_028B54_ES_STAGE_DS : V_028B54_ES_STAGE_REAL) |
S_028B54_GS_EN(gs_stage == MESA_SHADER_GEOMETRY) |
S_028B54_GS_EN(gs_stage == MESA_SHADER_GEOMETRY);
}
shader->ngg.vgt_shader_stages_en |=
S_028B54_PRIMGEN_EN(1) |
S_028B54_PRIMGEN_PASSTHRU_EN(gfx10_is_ngg_passthrough(shader)) |
S_028B54_PRIMGEN_PASSTHRU_NO_MSG(gfx10_is_ngg_passthrough(shader) &&
sscreen->info.family >= CHIP_NAVI23) |
S_028B54_NGG_WAVE_ID_EN(shader->info.num_streamout_vec4s != 0) |
S_028B54_NGG_WAVE_ID_EN(ngg_wave_id_en) |
S_028B54_GS_W32_EN(shader->wave_size == 32) |
S_028B54_MAX_PRIMGRP_IN_WAVE(2);
}
if (gs_stage == MESA_SHADER_MESH && sscreen->info.mesh_fast_launch_2) {
unsigned workgroup_threads =
gs_info->base.workgroup_size[0] *
gs_info->base.workgroup_size[1] *
gs_info->base.workgroup_size[2];
shader->ngg.spi_shader_gs_meshlet_dim =
S_00B2B0_MESHLET_NUM_THREAD_X(gs_info->base.workgroup_size[0] - 1) |
S_00B2B0_MESHLET_NUM_THREAD_Y(gs_info->base.workgroup_size[1] - 1) |
S_00B2B0_MESHLET_NUM_THREAD_Z(gs_info->base.workgroup_size[2] - 1) |
S_00B2B0_MESHLET_THREADGROUP_SIZE(workgroup_threads - 1);
shader->ngg.spi_shader_gs_meshlet_exp_alloc =
S_00B2B4_MAX_EXP_VERTS(gs_info->base.mesh.max_vertices_out) |
S_00B2B4_MAX_EXP_PRIMS(gs_info->base.mesh.max_primitives_out);
if (sscreen->info.gfx_level >= GFX12) {
const bool derivative_group_quads =
gs_info->base.derivative_group == DERIVATIVE_GROUP_QUADS;
shader->ngg.spi_shader_gs_meshlet_ctrl =
S_00B2B8_INTERLEAVE_BITS_X(derivative_group_quads) |
S_00B2B8_INTERLEAVE_BITS_Y(derivative_group_quads);
}
}
ac_pm4_finalize(&pm4->base);
}
@ -2251,6 +2338,9 @@ static void si_shader_init_pm4_state(struct si_screen *sscreen, struct si_shader
case MESA_SHADER_FRAGMENT:
si_shader_ps(sscreen, shader);
break;
case MESA_SHADER_MESH:
gfx10_shader_ngg(sscreen, shader);
break;
default:
assert(0);
}