diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index 2ec9d60eda3..96a14cc1bb9 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -2453,25 +2453,16 @@ emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s) } static void -emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s) +set_nv_ms_final_output_counts(nir_builder *b, + lower_ngg_ms_state *s, + nir_ssa_def **out_num_prm, + nir_ssa_def **out_num_vtx) { - /* We assume there is always a single end block in the shader. */ - nir_block *last_block = nir_impl_last_block(b->impl); - b->cursor = nir_after_block(last_block); - - nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, - .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared); - /* Limitations of the NV extension: * - Number of primitives can be written and read by any invocation, * so we have to store/load it to/from LDS to make sure the general case works. * - Number of vertices is not actually known, so we just always use the * maximum number here. - * - * TODO: in a possible cross-vendor extension we expect to be able do this smarter: - * - Lower SetMeshOutputCounts (not present in NV) directly to alloc_vertices_and_primitives. - * - We'll know the exact number of output vertices. - * - No longer need to ensure that these variables are readable by any invocation. */ nir_ssa_def *loaded_num_prm; nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32); @@ -2496,6 +2487,25 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s) } nir_pop_if(b, if_wave_0); + *out_num_prm = num_prm; + *out_num_vtx = num_vtx; +} + +static void +emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s) +{ + /* We assume there is always a single end block in the shader. */ + nir_block *last_block = nir_impl_last_block(b->impl); + b->cursor = nir_after_block(last_block); + + nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP, + .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared); + + nir_ssa_def *num_prm; + nir_ssa_def *num_vtx; + + set_nv_ms_final_output_counts(b, s, &num_prm, &num_vtx); + nir_ssa_def *invocation_index = nir_load_local_invocation_index(b); /* Load vertex/primitive attributes from shared memory and