diff --git a/src/amd/compiler/aco_insert_waitcnt.cpp b/src/amd/compiler/aco_insert_waitcnt.cpp
index 90d27c5be43..2d6cba63c09 100644
--- a/src/amd/compiler/aco_insert_waitcnt.cpp
+++ b/src/amd/compiler/aco_insert_waitcnt.cpp
@@ -168,7 +168,9 @@ private:
 
 enum barrier_info_kind {
    /* Waits for all non-private accesses and all scratch/vgpr-spill accesses */
-   barrier_info_all,
+   barrier_info_release_dep,
+   /* Waits for all atomics */
+   barrier_info_acquire_dep,
 
    num_barrier_infos,
 };
@@ -427,28 +429,29 @@ check_instr(wait_ctx& ctx, wait_imm& wait, Instruction* instr)
 }
 
 void
-perform_barrier(wait_ctx& ctx, wait_imm& imm, memory_sync_info sync, unsigned semantics)
+perform_barrier(wait_ctx& ctx, wait_imm& imm, memory_sync_info sync, bool is_acquire)
 {
    sync_scope subgroup_scope =
       ctx.program->workgroup_size <= ctx.program->wave_size ? scope_workgroup : scope_subgroup;
-   if ((sync.semantics & semantics) && sync.scope > subgroup_scope) {
-      barrier_info& bar = ctx.bar[barrier_info_all];
+   if (sync.scope <= subgroup_scope)
+      return;
 
-      u_foreach_bit (i, sync.storage & bar.storage) {
-         uint16_t events = bar.events[i];
+   barrier_info& bar = ctx.bar[is_acquire ? barrier_info_acquire_dep : barrier_info_release_dep];
 
-         /* LDS is private to the workgroup */
-         if (MIN2(sync.scope, scope_workgroup) <= subgroup_scope)
-            events &= ~event_lds;
+   u_foreach_bit (i, sync.storage & bar.storage) {
+      uint16_t events = bar.events[i];
 
-         /* Until GFX11, in non-WGP, the L1 (L0 on GFX10+) cache keeps all memory operations
-          * in-order for the same workgroup */
-         if (ctx.gfx_level < GFX11 && !ctx.program->wgp_mode && sync.scope <= scope_workgroup)
-            events &= ~(event_vmem | event_vmem_store);
+      /* LDS is private to the workgroup */
+      if (MIN2(sync.scope, scope_workgroup) <= subgroup_scope)
+         events &= ~event_lds;
 
-         if (events)
-            imm.combine(bar.imm[i]);
-      }
+      /* Until GFX11, in non-WGP, the L1 (L0 on GFX10+) cache keeps all memory operations
+       * in-order for the same workgroup */
+      if (ctx.gfx_level < GFX11 && !ctx.program->wgp_mode && sync.scope <= scope_workgroup)
+         events &= ~(event_vmem | event_vmem_store | event_smem);
+
+      if (events)
+         imm.combine(bar.imm[i]);
    }
 }
 
@@ -500,7 +503,7 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
     */
    if (ctx.gfx_level >= GFX11 && instr->opcode == aco_opcode::s_sendmsg &&
       instr->salu().imm == sendmsg_dealloc_vgprs) {
-      barrier_info& bar = ctx.bar[barrier_info_all];
+      barrier_info& bar = ctx.bar[barrier_info_release_dep];
       imm.combine(bar.imm[ffs(storage_scratch) - 1]);
       imm.combine(bar.imm[ffs(storage_vgpr_spill) - 1]);
    }
@@ -549,14 +552,18 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
 
    if (instr->opcode == aco_opcode::ds_ordered_count &&
       ((instr->ds().offset1 | (instr->ds().offset0 >> 8)) & 0x1)) {
-      barrier_info& bar = ctx.bar[barrier_info_all];
+      barrier_info& bar = ctx.bar[barrier_info_release_dep];
       imm.combine(bar.imm[ffs(storage_gds) - 1]);
    }
 
-   if (instr->opcode == aco_opcode::p_barrier)
-      perform_barrier(ctx, imm, instr->barrier().sync, semantic_acqrel);
-   else
-      perform_barrier(ctx, imm, sync_info, semantic_release);
+   if (instr->opcode == aco_opcode::p_barrier) {
+      if (instr->barrier().sync.semantics & semantic_release)
+         perform_barrier(ctx, imm, instr->barrier().sync, false);
+      if (instr->barrier().sync.semantics & semantic_acquire)
+         perform_barrier(ctx, imm, instr->barrier().sync, true);
+   } else if (sync_info.semantics & semantic_release) {
+      perform_barrier(ctx, imm, sync_info, false);
+   }
 
    if (!imm.empty()) {
       if (ctx.pending_flat_vm && imm.vm != wait_imm::unset_counter)
@@ -626,24 +633,31 @@ update_barrier_info_for_event(wait_ctx& ctx, uint8_t counters, wait_event event,
 }
 
 /* This resets or increases the counters for the barrier infos in response to an instruction. */
 void
-update_barriers(wait_ctx& ctx, uint8_t counters, wait_event event, memory_sync_info sync)
+update_barriers(wait_ctx& ctx, uint8_t counters, wait_event event, Instruction* instr,
+                memory_sync_info sync)
 {
-   uint16_t storage_all = sync.storage;
-   /* We re-use barrier_info_all to wait for all scratch stores to finish, to track those even if
-    * they are private. */
+   uint16_t storage_rel = sync.storage;
+   /* We re-use barrier_info_release_dep to wait for all scratch stores to finish, so track those
+    * even if they are private. */
   if (sync.semantics & semantic_private)
-      storage_all &= storage_scratch | storage_vgpr_spill;
-   update_barrier_info_for_event(ctx, counters, event, barrier_info_all, storage_all);
+      storage_rel &= storage_scratch | storage_vgpr_spill;
+   update_barrier_info_for_event(ctx, counters, event, barrier_info_release_dep, storage_rel);
+
+   if (instr) {
+      uint16_t storage_acq = is_atomic_or_control_instr(ctx.program, instr, sync, semantic_acquire);
+      update_barrier_info_for_event(ctx, counters, event, barrier_info_acquire_dep, storage_acq);
+   }
 }
 
 void
-update_counters(wait_ctx& ctx, wait_event event, memory_sync_info sync = memory_sync_info())
+update_counters(wait_ctx& ctx, wait_event event, Instruction* instr,
+                memory_sync_info sync = memory_sync_info())
 {
    uint8_t counters = ctx.info->get_counters_for_event(event);
    ctx.nonzero |= counters;
 
-   update_barriers(ctx, counters, event, sync);
+   update_barriers(ctx, counters, event, instr, sync);
 
    if (ctx.info->unordered_events & event)
       return;
@@ -715,7 +729,7 @@ gen(Instruction* instr, wait_ctx& ctx)
         ev = event_exp_pos;
      else
         ev = event_exp_param;
-      update_counters(ctx, ev);
+      update_counters(ctx, ev, instr);
 
      /* insert new entries for exported vgprs */
      for (unsigned i = 0; i < 4; i++) {
@@ -731,8 +745,8 @@ gen(Instruction* instr, wait_ctx& ctx)
   case Format::FLAT: {
      FLAT_instruction& flat = instr->flat();
      wait_event vmem_ev = get_vmem_event(ctx, instr, vmem_nosampler);
-      update_counters(ctx, vmem_ev, flat.sync);
-      update_counters(ctx, event_lds, flat.sync);
+      update_counters(ctx, vmem_ev, instr, flat.sync);
+      update_counters(ctx, event_lds, instr, flat.sync);
 
      if (!instr->definitions.empty())
         insert_wait_entry(ctx, instr->definitions[0], vmem_ev, 0, get_vmem_mask(ctx, instr));
@@ -747,7 +761,7 @@ gen(Instruction* instr, wait_ctx& ctx)
   }
   case Format::SMEM: {
      SMEM_instruction& smem = instr->smem();
-      update_counters(ctx, event_smem, smem.sync);
+      update_counters(ctx, event_smem, instr, smem.sync);
 
      if (!instr->definitions.empty())
        insert_wait_entry(ctx, instr->definitions[0], event_smem);
@@ -758,9 +772,9 @@ gen(Instruction* instr, wait_ctx& ctx)
   }
   case Format::DS: {
     DS_instruction& ds = instr->ds();
-      update_counters(ctx, ds.gds ? event_gds : event_lds, ds.sync);
+      update_counters(ctx, ds.gds ? event_gds : event_lds, instr, ds.sync);
     if (ds.gds)
-         update_counters(ctx, event_gds_gpr_lock);
+         update_counters(ctx, event_gds_gpr_lock, instr);
 
     for (auto& definition : instr->definitions)
        insert_wait_entry(ctx, definition, ds.gds ? event_gds : event_lds);
@@ -774,7 +788,7 @@ gen(Instruction* instr, wait_ctx& ctx)
   }
   case Format::LDSDIR: {
      LDSDIR_instruction& ldsdir = instr->ldsdir();
-      update_counters(ctx, event_ldsdir, ldsdir.sync);
+      update_counters(ctx, event_ldsdir, instr, ldsdir.sync);
      insert_wait_entry(ctx, instr->definitions[0], event_ldsdir);
      break;
   }
@@ -787,16 +801,16 @@ gen(Instruction* instr, wait_ctx& ctx)
      wait_event ev = get_vmem_event(ctx, instr, type);
      uint32_t mask = ev == event_vmem ? get_vmem_mask(ctx, instr) : 0;
 
-      update_counters(ctx, ev, get_sync_info(instr));
+      update_counters(ctx, ev, instr, get_sync_info(instr));
 
      for (auto& definition : instr->definitions)
        insert_wait_entry(ctx, definition, ev, type, mask);
 
     if (ctx.gfx_level == GFX6 && instr->format != Format::MIMG && instr->operands.size() == 4) {
-         update_counters(ctx, event_vmem_gpr_lock);
+         update_counters(ctx, event_vmem_gpr_lock, instr);
        insert_wait_entry(ctx, instr->operands[3], event_vmem_gpr_lock);
     } else if (ctx.gfx_level == GFX6 && instr->isMIMG() && !instr->operands[2].isUndefined()) {
-         update_counters(ctx, event_vmem_gpr_lock);
+         update_counters(ctx, event_vmem_gpr_lock, instr);
        insert_wait_entry(ctx, instr->operands[2], event_vmem_gpr_lock);
     }
 
@@ -804,13 +818,13 @@ gen(Instruction* instr, wait_ctx& ctx)
   }
   case Format::SOPP: {
     if (instr->opcode == aco_opcode::s_sendmsg || instr->opcode == aco_opcode::s_sendmsghalt)
-         update_counters(ctx, event_sendmsg);
+         update_counters(ctx, event_sendmsg, instr);
     break;
   }
   case Format::SOP1: {
     if (instr->opcode == aco_opcode::s_sendmsg_rtn_b32 ||
        instr->opcode == aco_opcode::s_sendmsg_rtn_b64) {
-         update_counters(ctx, event_sendmsg);
+         update_counters(ctx, event_sendmsg, instr);
        insert_wait_entry(ctx, instr->definitions[0], event_sendmsg);
     }
     break;
@@ -900,10 +914,11 @@ handle_block(Program* program, Block& block, wait_ctx& ctx)
            !((instr->ds().offset1 | (instr->ds().offset0 >> 8)) & 0x1);
 
         new_instructions.emplace_back(std::move(instr));
-         perform_barrier(ctx, queued_imm, sync_info, semantic_acquire);
+         if (sync_info.semantics & semantic_acquire)
+            perform_barrier(ctx, queued_imm, sync_info, true);
 
         if (is_ordered_count_acquire)
-            queued_imm.combine(ctx.bar[barrier_info_all].imm[ffs(storage_gds) - 1]);
+            queued_imm.combine(ctx.bar[barrier_info_release_dep].imm[ffs(storage_gds) - 1]);
      }
   }
 
@@ -935,12 +950,12 @@ insert_waitcnt(Program* program)
 
   unsigned loop_progress = 0;
 
   if (program->pending_lds_access) {
-      update_barriers(in_ctx[0], info.get_counters_for_event(event_lds), event_lds,
+      update_barriers(in_ctx[0], info.get_counters_for_event(event_lds), event_lds, NULL,
                      memory_sync_info(storage_shared));
   }
 
   for (Definition def : program->args_pending_vmem) {
-      update_counters(in_ctx[0], event_vmem);
+      update_counters(in_ctx[0], event_vmem, NULL);
      insert_wait_entry(in_ctx[0], def, event_vmem, vmem_nosampler, 0xffffffff);
   }
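
For readers of the patch, here is a standalone sketch (not ACO code, and not part of the change) that only illustrates why the pass now tracks two dependency sets instead of the single barrier_info_all: a release barrier must wait for every prior non-private access, while an acquire barrier only needs to wait for prior atomics, so the two sets are kept separately and one of them is selected per barrier. All names in the sketch (dep_kind, barrier_tracker, record_access, ops_to_wait_for) are hypothetical and exist only for illustration.

/* Standalone illustration of the release/acquire dependency split. */
#include <cstdio>

enum dep_kind { dep_release, dep_acquire, num_dep_kinds };

struct barrier_tracker {
   /* Number of still-outstanding operations per dependency kind. */
   unsigned outstanding[num_dep_kinds] = {0, 0};

   /* Every tracked access is a release dependency; only atomics additionally
    * become acquire dependencies, because a later acquire must order against
    * them. */
   void record_access(bool is_atomic)
   {
      outstanding[dep_release]++;
      if (is_atomic)
         outstanding[dep_acquire]++;
   }

   /* At a barrier, pick the dependency set matching its semantics, analogous
    * to ctx.bar[is_acquire ? barrier_info_acquire_dep : barrier_info_release_dep]
    * in the hunks above. */
   unsigned ops_to_wait_for(bool is_acquire) const
   {
      return outstanding[is_acquire ? dep_acquire : dep_release];
   }
};

int main()
{
   barrier_tracker bar;
   bar.record_access(false); /* plain store */
   bar.record_access(false); /* plain store */
   bar.record_access(true);  /* atomic */

   std::printf("release waits for %u ops\n", bar.ops_to_wait_for(false)); /* 3 */
   std::printf("acquire waits for %u ops\n", bar.ops_to_wait_for(true));  /* 1 */
   return 0;
}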