ac/nir: implement mesh shader gs_fast_launch=2

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25040>
This commit is contained in:
Rhys Perry 2023-09-01 11:24:56 +01:00 committed by Marge Bot
parent 75bc2e7149
commit cddbe9a4c2
3 changed files with 17 additions and 8 deletions

View file

@ -197,7 +197,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
bool *out_needs_scratch_ring,
unsigned wave_size,
bool multiview,
bool has_query);
bool has_query,
bool fast_launch_2);
void
ac_nir_lower_task_outputs_to_mem(nir_shader *shader,

View file

@ -197,6 +197,7 @@ typedef struct
typedef struct
{
enum amd_gfx_level gfx_level;
bool fast_launch_2;
ms_out_mem_layout layout;
uint64_t per_vertex_outputs;
@ -4513,6 +4514,10 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
ms_prim_gen_query(b, invocation_index, num_prm, s);
nir_def *row_start = NULL;
if (s->fast_launch_2)
row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b);
/* Load vertex/primitive attributes from shared memory and
* emit store_output intrinsics for them.
*
@ -4544,7 +4549,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
nir_def *has_output_vertex = nir_ilt(b, invocation_index, num_vtx);
nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
{
emit_ms_vertex(b, invocation_index, NULL, !wait_attr_ring, true, per_vertex_outputs, s);
emit_ms_vertex(b, invocation_index, row_start, !wait_attr_ring, true, per_vertex_outputs, s);
}
nir_pop_if(b, if_has_output_vertex);
}
@ -4554,7 +4559,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
nir_def *has_output_primitive = nir_ilt(b, invocation_index, num_prm);
nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
{
emit_ms_primitive(b, invocation_index, NULL, !wait_attr_ring, true, per_primitive_outputs, s);
emit_ms_primitive(b, invocation_index, row_start, !wait_attr_ring, true, per_primitive_outputs, s);
}
nir_pop_if(b, if_has_output_primitive);
}
@ -4574,14 +4579,14 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
nir_def *has_output_vertex = nir_ilt(b, invocation_index, num_vtx);
nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
{
emit_ms_vertex(b, invocation_index, NULL, true, false, per_vertex_outputs, s);
emit_ms_vertex(b, invocation_index, row_start, true, false, per_vertex_outputs, s);
}
nir_pop_if(b, if_has_output_vertex);
nir_def *has_output_primitive = nir_ilt(b, invocation_index, num_prm);
nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
{
emit_ms_primitive(b, invocation_index, NULL, true, false, per_primitive_outputs, s);
emit_ms_primitive(b, invocation_index, row_start, true, false, per_primitive_outputs, s);
}
nir_pop_if(b, if_has_output_primitive);
}
@ -4866,7 +4871,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
bool *out_needs_scratch_ring,
unsigned wave_size,
bool multiview,
bool has_query)
bool has_query,
bool fast_launch_2)
{
unsigned vertices_per_prim =
num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
@ -4917,6 +4923,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
.insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
.uses_cull_flags = uses_cull,
.gfx_level = gfx_level,
.fast_launch_2 = fast_launch_2,
.clipdist_enable_mask = clipdist_enable_mask,
.vs_output_param_offset = vs_output_param_offset,
.has_param_exports = has_param_exports,
@ -4935,7 +4942,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
nir_builder *b = &builder; /* This is to avoid the & */
handle_smaller_ms_api_workgroup(b, &state);
ms_emit_legacy_workgroup_index(b, &state);
if (!fast_launch_2)
ms_emit_legacy_workgroup_index(b, &state);
ms_create_same_invocation_vars(b, &state);
nir_metadata_preserve(impl, nir_metadata_none);

View file

@ -916,7 +916,7 @@ radv_lower_ngg(struct radv_device *device, struct radv_shader_stage *ngg_stage,
bool scratch_ring = false;
NIR_PASS_V(nir, ac_nir_lower_ngg_ms, options.gfx_level, options.clipdist_enable_mask,
options.vs_output_param_offset, options.has_param_exports, &scratch_ring, info->wave_size,
pl_key->has_multiview_view_index, info->ms.has_query);
pl_key->has_multiview_view_index, info->ms.has_query, false);
ngg_stage->info.ms.needs_ms_scratch_ring = scratch_ring;
} else {
unreachable("invalid SW stage passed to radv_lower_ngg");