From 57dec0678e48a10df756f36f910decd1ec0552a7 Mon Sep 17 00:00:00 2001 From: Samuel Pitoiset Date: Wed, 20 Sep 2023 17:03:29 +0200 Subject: [PATCH] ac/nir: add lowering for mesh shader queries Signed-off-by: Samuel Pitoiset Part-of: --- src/amd/common/ac_nir.h | 3 +- src/amd/common/ac_nir_lower_ngg.c | 50 ++++++++++++++++++++++++++++++- src/amd/vulkan/radv_shader.c | 2 +- 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h index c9072c863bd..7c560e3b563 100644 --- a/src/amd/common/ac_nir.h +++ b/src/amd/common/ac_nir.h @@ -195,7 +195,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader, bool has_param_exports, bool *out_needs_scratch_ring, unsigned wave_size, - bool multiview); + bool multiview, + bool has_query); void ac_nir_lower_task_outputs_to_mem(nir_shader *shader, diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index e4a850a09ec..0872c1e2f15 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -227,6 +227,9 @@ typedef struct uint32_t clipdist_enable_mask; const uint8_t *vs_output_param_offset; bool has_param_exports; + + /* True if the lowering needs to insert shader query. */ + bool has_query; } lower_ngg_ms_state; /* Per-vertex LDS layout of culling shaders */ @@ -4401,6 +4404,45 @@ ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s return prim_exp_arg_ch2; } +static void +ms_prim_gen_query(nir_builder *b, + nir_def *invocation_index, + nir_def *num_prm, + lower_ngg_ms_state *s) +{ + if (!s->has_query) + return; + + nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); + { + nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b)); + { + nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0); + } + nir_pop_if(b, if_shader_query); + } + nir_pop_if(b, if_invocation_index_zero); +} + +static void +ms_invocation_query(nir_builder *b, + nir_def *invocation_index, + lower_ngg_ms_state *s) +{ + if (!s->has_query) + return; + + nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); + { + nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b)); + { + nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size)); + } + nir_pop_if(b, if_pipeline_query); + } + nir_pop_if(b, if_invocation_index_zero); +} + static void ms_emit_primitive_export(nir_builder *b, nir_def *invocation_index, @@ -4435,6 +4477,8 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s) nir_def *invocation_index = nir_load_local_invocation_index(b); + ms_prim_gen_query(b, invocation_index, num_prm, s); + /* Load vertex/primitive attributes from shared memory and * emit store_output intrinsics for them. * @@ -4664,6 +4708,8 @@ handle_smaller_ms_api_workgroup(nir_builder *b, .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_shader_out | nir_var_mem_shared); } + + ms_invocation_query(b, invocation_index, s); } nir_pop_if(b, if_has_api_ms_invocation); @@ -4832,7 +4878,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader, bool has_param_exports, bool *out_needs_scratch_ring, unsigned wave_size, - bool multiview) + bool multiview, + bool has_query) { unsigned vertices_per_prim = num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type); @@ -4886,6 +4933,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader, .clipdist_enable_mask = clipdist_enable_mask, .vs_output_param_offset = vs_output_param_offset, .has_param_exports = has_param_exports, + .has_query = has_query, }; nir_function_impl *impl = nir_shader_get_entrypoint(shader); diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index eecfe43acf2..4d718652425 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -912,7 +912,7 @@ radv_lower_ngg(struct radv_device *device, struct radv_shader_stage *ngg_stage, bool scratch_ring = false; NIR_PASS_V(nir, ac_nir_lower_ngg_ms, options.gfx_level, options.clipdist_enable_mask, options.vs_output_param_offset, options.has_param_exports, &scratch_ring, info->wave_size, - pl_key->has_multiview_view_index); + pl_key->has_multiview_view_index, false); ngg_stage->info.ms.needs_ms_scratch_ring = scratch_ring; } else { unreachable("invalid SW stage passed to radv_lower_ngg");