aco: Use workgroup size from input shader info.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12321>
This commit is contained in:
Timur Kristóf 2021-08-11 10:09:04 +02:00 committed by Marge Bot
parent c4ca08548b
commit 626b125857

View file

@ -1011,61 +1011,13 @@ setup_isel_context(Program* program, unsigned shader_count, struct nir_shader* c
ctx.options = args->options;
ctx.stage = program->stage;
/* TODO: Check if we need to adjust min_waves for unknown workgroup sizes. */
if (program->stage.hw == HWStage::VS || program->stage.hw == HWStage::FS) {
/* PS and legacy VS have separate waves, no workgroups */
program->workgroup_size = program->wave_size;
} else if (program->stage == compute_cs) {
/* CS sets the workgroup size explicitly */
program->workgroup_size = shaders[0]->info.workgroup_size[0] *
shaders[0]->info.workgroup_size[1] *
shaders[0]->info.workgroup_size[2];
} else if (program->stage.hw == HWStage::ES || program->stage == geometry_gs) {
/* Unmerged ESGS operate in workgroups if on-chip GS (LDS rings) are enabled on GFX7-8
* (not implemented in Mesa) */
program->workgroup_size = program->wave_size;
} else if (program->stage.hw == HWStage::GS) {
/* If on-chip GS (LDS rings) are enabled on GFX9 or later, merged GS operates in workgroups */
assert(program->chip_class >= GFX9);
uint32_t es_verts_per_subgrp =
G_028A44_ES_VERTS_PER_SUBGRP(program->info->gs_ring_info.vgt_gs_onchip_cntl);
uint32_t gs_instr_prims_in_subgrp =
G_028A44_GS_INST_PRIMS_IN_SUBGRP(program->info->gs_ring_info.vgt_gs_onchip_cntl);
uint32_t workgroup_size = MAX2(es_verts_per_subgrp, gs_instr_prims_in_subgrp);
program->workgroup_size = MAX2(MIN2(workgroup_size, 256), 1);
} else if (program->stage == vertex_ls) {
/* Unmerged LS operates in workgroups */
program->workgroup_size = UINT_MAX; /* TODO: probably tcs_num_patches * tcs_vertices_in, but
those are not plumbed to ACO for LS */
} else if (program->stage == tess_control_hs) {
/* Unmerged HS operates in workgroups, size is determined by the output vertices */
program->workgroup_size = args->shader_info->workgroup_size;
assert(program->workgroup_size);
if (ctx.stage == tess_control_hs)
setup_tcs_info(&ctx, shaders[0], NULL);
program->workgroup_size = ctx.tcs_num_patches * shaders[0]->info.tess.tcs_vertices_out;
} else if (program->stage == vertex_tess_control_hs) {
/* Merged LSHS operates in workgroups, but can still have a different number of LS and HS
* invocations */
else if (ctx.stage == vertex_tess_control_hs)
setup_tcs_info(&ctx, shaders[1], shaders[0]);
program->workgroup_size =
ctx.tcs_num_patches *
MAX2(shaders[1]->info.tess.tcs_vertices_out, ctx.args->options->key.tcs.input_vertices);
} else if (program->stage.hw == HWStage::NGG) {
gfx10_ngg_info& ngg_info = args->shader_info->ngg_info;
unsigned num_gs_invocations =
(program->stage.has(SWStage::GS)) ? MAX2(shaders[1]->info.gs.invocations, 1) : 1;
/* Max ES (SW VS/TES) threads */
uint32_t max_esverts = ngg_info.hw_max_esverts;
/* Max GS input primitives = max GS threads */
uint32_t max_gs_input_prims = ngg_info.max_gsprims * num_gs_invocations;
/* Maximum output vertices -- each thread can export only 1 vertex */
uint32_t max_out_vtx = ngg_info.max_out_verts;
/* Maximum output primitives -- each thread can export only 1 or 0 primitive */
uint32_t max_out_prm = ngg_info.max_gsprims * num_gs_invocations * ngg_info.prim_amp_factor;
program->workgroup_size = MAX4(max_esverts, max_gs_input_prims, max_out_vtx, max_out_prm);
} else {
unreachable("Unsupported shader stage.");
}
calc_min_waves(program);