ac/nir: Store mesh shader API and HW workgroup size in lowering state.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15199>
This commit is contained in:
Timur Kristóf 2022-02-27 18:39:01 +01:00 committed by Marge Bot
parent d0f45c7c49
commit 57775dd76a

View file

@ -107,6 +107,8 @@ typedef struct
unsigned prim_vtx_indices_addr;
unsigned numprims_lds_addr;
unsigned wave_size;
unsigned api_workgroup_size;
unsigned hw_workgroup_size;
struct {
/* Bitmask of components used: 4 bits per slot, 1 bit per component. */
@ -2457,12 +2459,12 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
}
static void
handle_smaller_ms_api_workgroup(nir_function_impl *impl,
nir_builder *b,
unsigned api_workgroup_size,
unsigned hw_workgroup_size,
handle_smaller_ms_api_workgroup(nir_builder *b,
lower_ngg_ms_state *s)
{
if (s->api_workgroup_size >= s->hw_workgroup_size)
return;
/* Handle barriers manually when the API workgroup
* size is less than the HW workgroup size.
*
@ -2478,19 +2480,19 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl,
* all. In this case, we emit code that consumes every
* barrier on the extra waves.
*/
assert(hw_workgroup_size % s->wave_size == 0);
bool scan_barriers = ALIGN(api_workgroup_size, s->wave_size) < hw_workgroup_size;
bool can_shrink_barriers = api_workgroup_size <= s->wave_size;
assert(s->hw_workgroup_size % s->wave_size == 0);
bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
unsigned api_waves_in_flight_addr = s->numprims_lds_addr + 12;
unsigned num_api_waves = DIV_ROUND_UP(api_workgroup_size, s->wave_size);
unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
/* Scan the shader for workgroup barriers. */
if (scan_barriers) {
bool has_any_workgroup_barriers = false;
nir_foreach_block(block, impl) {
nir_foreach_block(block, b->impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
@ -2521,8 +2523,8 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl,
/* Extract the full control flow of the shader. */
nir_cf_list extracted;
nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
b->cursor = nir_before_cf_list(&impl->body);
nir_cf_extract(&extracted, nir_before_cf_list(&b->impl->body), nir_after_cf_list(&b->impl->body));
b->cursor = nir_before_cf_list(&b->impl->body);
/* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */
nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
@ -2542,7 +2544,7 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl,
.memory_modes = nir_var_shader_out | nir_var_mem_shared);
}
nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, api_workgroup_size));
nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, s->api_workgroup_size));
nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
{
nir_cf_reinsert(&extracted, b->cursor);
@ -2638,19 +2640,6 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
shader->info.shared_size = prim_vtx_indices_addr + prim_vtx_indices_size;
lower_ngg_ms_state state = {
.wave_size = wave_size,
.per_vertex_outputs = per_vertex_outputs,
.per_primitive_outputs = per_primitive_outputs,
.num_per_vertex_outputs = num_per_vertex_outputs,
.num_per_primitive_outputs = num_per_primitive_outputs,
.vertices_per_prim = vertices_per_prim,
.vertex_attr_lds_addr = vertex_attr_lds_addr,
.prim_attr_lds_addr = prim_attr_lds_addr,
.prim_vtx_indices_addr = prim_vtx_indices_addr,
.numprims_lds_addr = numprims_lds_addr,
};
/* The workgroup size that is specified by the API shader may be different
* from the size of the workgroup that actually runs on the HW, due to the
* limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed.
@ -2665,6 +2654,21 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
unsigned hw_workgroup_size =
ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size);
lower_ngg_ms_state state = {
.wave_size = wave_size,
.per_vertex_outputs = per_vertex_outputs,
.per_primitive_outputs = per_primitive_outputs,
.num_per_vertex_outputs = num_per_vertex_outputs,
.num_per_primitive_outputs = num_per_primitive_outputs,
.vertices_per_prim = vertices_per_prim,
.vertex_attr_lds_addr = vertex_attr_lds_addr,
.prim_attr_lds_addr = prim_attr_lds_addr,
.prim_vtx_indices_addr = prim_vtx_indices_addr,
.numprims_lds_addr = numprims_lds_addr,
.api_workgroup_size = api_workgroup_size,
.hw_workgroup_size = hw_workgroup_size,
};
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
assert(impl);
@ -2673,9 +2677,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
nir_builder_init(b, impl);
b->cursor = nir_before_cf_list(&impl->body);
if (api_workgroup_size < hw_workgroup_size) {
handle_smaller_ms_api_workgroup(impl, b, api_workgroup_size, hw_workgroup_size, &state);
}
handle_smaller_ms_api_workgroup(b, &state);
lower_ms_intrinsics(shader, &state);
emit_ms_finale(b, &state);