diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index ca7498ea1c7..583a1c1b222 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -2417,13 +2417,153 @@ emit_ms_finale(nir_shader *shader, lower_ngg_ms_state *s) nir_pop_if(b, if_has_output_primitive); } +static void +handle_smaller_ms_api_workgroup(nir_function_impl *impl, + nir_builder *b, + unsigned api_workgroup_size, + unsigned hw_workgroup_size, + lower_ngg_ms_state *s) +{ + /* Handle barriers manually when the API workgroup + * size is less than the HW workgroup size. + * + * The problem is that the real workgroup launched on NGG HW + * will be larger than the size specified by the API, and the + * extra waves need to keep up with barriers in the API waves. + * + * There are 2 different cases: + * 1. The whole API workgroup fits in a single wave. + * We can shrink the barriers to subgroup scope and + * don't need to insert any extra ones. + * 2. The API workgroup occupies multiple waves, but not + * all. In this case, we emit code that consumes every + * barrier on the extra waves. + */ + assert(hw_workgroup_size % s->wave_size == 0); + bool scan_barriers = ALIGN(api_workgroup_size, s->wave_size) < hw_workgroup_size; + bool can_shrink_barriers = api_workgroup_size <= s->wave_size; + bool need_additional_barriers = scan_barriers && !can_shrink_barriers; + + unsigned api_waves_in_flight_addr = s->numprims_lds_addr + 12; + unsigned num_api_waves = DIV_ROUND_UP(api_workgroup_size, s->wave_size); + + /* Scan the shader for workgroup barriers. */ + if (scan_barriers) { + bool has_any_workgroup_barriers = false; + + nir_foreach_block(block, impl) { + nir_foreach_instr_safe(instr, block) { + if (instr->type != nir_instr_type_intrinsic) + continue; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + bool is_workgroup_barrier = + intrin->intrinsic == nir_intrinsic_scoped_barrier && + nir_intrinsic_execution_scope(intrin) == NIR_SCOPE_WORKGROUP; + + if (!is_workgroup_barrier) + continue; + + if (can_shrink_barriers) { + /* Every API invocation runs in the first wave. + * In this case, we can change the barriers to subgroup scope + * and avoid adding additional barriers. + */ + nir_intrinsic_set_memory_scope(intrin, NIR_SCOPE_SUBGROUP); + nir_intrinsic_set_execution_scope(intrin, NIR_SCOPE_SUBGROUP); + } else { + has_any_workgroup_barriers = true; + } + } + } + + need_additional_barriers &= has_any_workgroup_barriers; + } + + /* Extract the full control flow of the shader. */ + nir_cf_list extracted; + nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body)); + b->cursor = nir_before_cf_list(&impl->body); + + /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */ + nir_ssa_def *invocation_index = nir_load_local_invocation_index(b); + nir_ssa_def *zero = nir_imm_int(b, 0); + + if (need_additional_barriers) { + /* First invocation stores 0 to number of API waves in flight. */ + nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); + { + nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr); + } + nir_pop_if(b, if_first_in_workgroup); + + nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, + .memory_scope = NIR_SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_shader_out | nir_var_mem_shared); + } + + nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, api_workgroup_size)); + nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation); + { + nir_cf_reinsert(&extracted, b->cursor); + b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list); + + if (need_additional_barriers) { + /* One invocation in each API wave decrements the number of API waves in flight. */ + nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1)); + { + nir_shared_atomic_add(b, 32, zero, nir_imm_int(b, -1u), .base = api_waves_in_flight_addr); + } + nir_pop_if(b, if_elected_again); + + nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, + .memory_scope = NIR_SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_shader_out | nir_var_mem_shared); + } + } + nir_pop_if(b, if_has_api_ms_invocation); + + if (need_additional_barriers) { + /* Make sure that waves that don't run any API invocations execute + * the same amount of barriers as those that do. + * + * We do this by executing a barrier until the number of API waves + * in flight becomes zero. + */ + nir_ssa_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation); + nir_ssa_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0); + nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms); + { + nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); + { + nir_loop *loop = nir_push_loop(b); + { + nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, + .memory_scope = NIR_SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_shader_out | nir_var_mem_shared); + + nir_ssa_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr); + nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0)); + { + nir_jump(b, nir_jump_break); + } + nir_pop_if(b, if_break); + } + nir_pop_loop(b, loop); + } + nir_pop_if(b, if_elected); + } + nir_pop_if(b, if_wave_has_no_api_ms); + } +} + void ac_nir_lower_ngg_ms(nir_shader *shader, unsigned wave_size) { - nir_function_impl *impl = nir_shader_get_entrypoint(shader); - assert(impl); - unsigned vertices_per_prim = num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type); @@ -2436,9 +2576,14 @@ ac_nir_lower_ngg_ms(nir_shader *shader, unsigned max_vertices = shader->info.mesh.max_vertices_out; unsigned max_primitives = shader->info.mesh.max_primitives_out; - /* LDS area for total number of output primitives. */ + /* LDS area for total number of output primitives and other info. + * DW0: number of primitives + * DW1: reserved for later use + * DW2: reserved for later use + * DW3: number of API workgroups in flight + */ unsigned numprims_lds_addr = ALIGN(shader->info.shared_size, 16); - unsigned numprims_lds_size = 4; + unsigned numprims_lds_size = 16; /* LDS area for vertex attributes */ unsigned vertex_attr_lds_addr = ALIGN(numprims_lds_addr + numprims_lds_size, 16); unsigned vertex_attr_lds_size = max_vertices * num_per_vertex_outputs * 16; @@ -2464,29 +2609,31 @@ ac_nir_lower_ngg_ms(nir_shader *shader, .numprims_lds_addr = numprims_lds_addr, }; - /* Extract the full control flow. It is going to be wrapped in an if statement. */ - nir_cf_list extracted; - nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body)); + /* The workgroup size that is specified by the API shader may be different + * from the size of the workgroup that actually runs on the HW, due to the + * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed. + * + * Therefore, we must make sure that when the API workgroup size is smaller, + * we don't run the API shader on more HW invocations than is necessary. + */ + unsigned api_workgroup_size = shader->info.workgroup_size[0] * + shader->info.workgroup_size[1] * + shader->info.workgroup_size[2]; + + unsigned hw_workgroup_size = + ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size); + + nir_function_impl *impl = nir_shader_get_entrypoint(shader); + assert(impl); nir_builder builder; nir_builder *b = &builder; /* This is to avoid the & */ nir_builder_init(b, impl); b->cursor = nir_before_cf_list(&impl->body); - /* There may be a difference between MS workgroup size and the - * number of output vertices/primitives. So it is possible that the actual H - * workgroup is larger than what the user wants. - * So, only execute the API shader for invocations that the user needs. - */ - unsigned num_ms_invocations = b->shader->info.workgroup_size[0] * - b->shader->info.workgroup_size[1] * - b->shader->info.workgroup_size[2]; - nir_ssa_def *invocation_index = nir_load_local_invocation_index(b); - nir_ssa_def *has_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, num_ms_invocations)); - nir_if *if_has_ms_invocation = nir_push_if(b, has_ms_invocation); - nir_cf_reinsert(&extracted, b->cursor); - b->cursor = nir_after_cf_list(&if_has_ms_invocation->then_list); - nir_pop_if(b, if_has_ms_invocation); + if (api_workgroup_size < hw_workgroup_size) { + handle_smaller_ms_api_workgroup(impl, b, api_workgroup_size, hw_workgroup_size, &state); + } lower_ms_intrinsics(shader, &state); emit_ms_finale(shader, &state);