diff --git a/src/amd/compiler/aco_insert_waitcnt.cpp b/src/amd/compiler/aco_insert_waitcnt.cpp index bb3e362a0a6..c3ff61a6f27 100644 --- a/src/amd/compiler/aco_insert_waitcnt.cpp +++ b/src/amd/compiler/aco_insert_waitcnt.cpp @@ -69,7 +69,10 @@ enum wait_event : uint16_t { event_vmem_gpr_lock = 1 << 10, event_sendmsg = 1 << 11, event_ldsdir = 1 << 12, - num_events = 13, + event_valu = 1 << 13, + event_trans = 1 << 14, + event_salu = 1 << 15, + num_events = 16, }; enum counter_type : uint8_t { @@ -77,7 +80,8 @@ enum counter_type : uint8_t { counter_lgkm = 1 << 1, counter_vm = 1 << 2, counter_vs = 1 << 3, - num_counters = 4, + counter_alu = 1 << 4, + num_counters = 5, }; enum vmem_type : uint8_t { @@ -93,6 +97,91 @@ static const uint16_t lgkm_events = event_smem | event_lds | event_gds | event_f static const uint16_t vm_events = event_vmem | event_flat; static const uint16_t vs_events = event_vmem_store; +/* On GFX11+ the SIMD frontend doesn't switch to issuing instructions from a different + * wave if there is an ALU stall. Hence we have an instruction (s_delay_alu) to signal + * that we should switch to a different wave and which contains info on dependencies as to + * when we can switch back. + * + * This seems to apply only for ALU->ALU dependencies as other instructions have better + * integration with the frontend. + * + * Note that if we do not emit s_delay_alu things will still be correct, but the wave + * will stall in the ALU (and the ALU will be doing nothing else). We'll use this as + * I'm pretty sure our cycle info is wrong at times (necessarily so, e.g. wave64 VALU + * instructions can take a different number of cycles based on the exec mask) + */ +struct alu_delay_info { + /* These are the values directly above the max representable value, i.e. the wait + * would turn into a no-op when we try to wait for something further back than + * this. 
+ */ + static constexpr int8_t valu_nop = 5; + static constexpr int8_t trans_nop = 4; + + /* How many VALU instructions ago this value was written */ + int8_t valu_instrs = valu_nop; + /* Cycles until the writing VALU instruction is finished */ + int8_t valu_cycles = 0; + + /* How many Transcendent instructions ago this value was written */ + int8_t trans_instrs = trans_nop; + /* Cycles until the writing Transcendent instruction is finished */ + int8_t trans_cycles = 0; + + /* Cycles until the writing SALU instruction is finished */ + int8_t salu_cycles = 0; + + bool combine(const alu_delay_info& other) + { + bool changed = other.valu_instrs < valu_instrs || other.trans_instrs < trans_instrs || + other.salu_cycles > salu_cycles || other.valu_cycles > valu_cycles || + other.trans_cycles > trans_cycles; + valu_instrs = std::min(valu_instrs, other.valu_instrs); + trans_instrs = std::min(trans_instrs, other.trans_instrs); + salu_cycles = std::max(salu_cycles, other.salu_cycles); + valu_cycles = std::max(valu_cycles, other.valu_cycles); + trans_cycles = std::max(trans_cycles, other.trans_cycles); + return changed; + } + + /* Needs to be called after any change to keep the data consistent. 
*/ + void fixup() + { + if (valu_instrs >= valu_nop || valu_cycles <= 0) { + valu_instrs = valu_nop; + valu_cycles = 0; + } + + if (trans_instrs >= trans_nop || trans_cycles <= 0) { + trans_instrs = trans_nop; + trans_cycles = 0; + } + + salu_cycles = std::max<int8_t>(salu_cycles, 0); + } + + /* Returns true if a wait would be a no-op */ + bool empty() const + { + return valu_instrs == valu_nop && trans_instrs == trans_nop && salu_cycles == 0; + } +}; + +enum class alu_delay_wait { + NO_DEP, + VALU_DEP_1, + VALU_DEP_2, + VALU_DEP_3, + VALU_DEP_4, + TRANS32_DEP_1, + TRANS32_DEP_2, + TRANS32_DEP_3, + FMA_ACCUM_CYCLE_1, + SALU_CYCLE_1, + SALU_CYCLE_2, + SALU_CYCLE_3 +}; + uint8_t get_counters_for_event(wait_event ev) { @@ -110,20 +199,25 @@ get_counters_for_event(wait_event ev) case event_gds_gpr_lock: case event_vmem_gpr_lock: case event_ldsdir: return counter_exp; + case event_valu: + case event_trans: + case event_salu: return counter_alu; default: return 0; } } struct wait_entry { wait_imm imm; + alu_delay_info delay; uint16_t events; /* use wait_event notion */ uint8_t counters; /* use counter_type notion */ bool wait_on_read : 1; bool logical : 1; uint8_t vmem_types : 4; - wait_entry(wait_event event_, wait_imm imm_, bool logical_, bool wait_on_read_) - : imm(imm_), events(event_), counters(get_counters_for_event(event_)), + wait_entry(wait_event event_, wait_imm imm_, alu_delay_info delay_, bool logical_, + bool wait_on_read_) + : imm(imm_), delay(delay_), events(event_), counters(get_counters_for_event(event_)), wait_on_read(wait_on_read_), logical(logical_), vmem_types(0) {} @@ -134,6 +228,7 @@ struct wait_entry { events |= other.events; counters |= other.counters; changed |= imm.combine(other.imm); + changed |= delay.combine(other.delay); wait_on_read |= other.wait_on_read; vmem_types |= other.vmem_types; assert(logical == other.logical); @@ -167,6 +262,11 @@ struct wait_entry { if (!(counters & counter_lgkm) && !(counters & counter_vm)) events &= ~event_flat; + + 
if (counter == counter_alu) { + delay = alu_delay_info(); + events &= ~(event_valu | event_trans | event_salu); + } } }; @@ -258,7 +358,7 @@ get_vmem_type(Instruction* instr) } void -check_instr(wait_ctx& ctx, wait_imm& wait, Instruction* instr) +check_instr(wait_ctx& ctx, wait_imm& wait, alu_delay_info& delay, Instruction* instr) { for (const Operand op : instr->operands) { if (op.isConstant() || op.isUndefined()) @@ -272,6 +372,8 @@ check_instr(wait_ctx& ctx, wait_imm& wait, Instruction* instr) continue; wait.combine(it->second.imm); + if (instr->isVALU() || instr->isSALU() || instr->isVINTERP_INREG()) + delay.combine(it->second.delay); } } @@ -314,6 +416,25 @@ parse_wait_instr(wait_ctx& ctx, wait_imm& imm, Instruction* instr) return false; } +bool +parse_delay_alu(wait_ctx& ctx, alu_delay_info& delay, Instruction* instr) +{ + if (instr->opcode != aco_opcode::s_delay_alu) + return false; + + unsigned imm[2] = {instr->sopp().imm & 0xf, (instr->sopp().imm >> 7) & 0xf}; + for (unsigned i = 0; i < 2; ++i) { + alu_delay_wait wait = (alu_delay_wait)imm[i]; + if (wait >= alu_delay_wait::VALU_DEP_1 && wait <= alu_delay_wait::VALU_DEP_4) + delay.valu_instrs = imm[i] - (uint32_t)alu_delay_wait::VALU_DEP_1 + 1; + else if (wait >= alu_delay_wait::TRANS32_DEP_1 && wait <= alu_delay_wait::TRANS32_DEP_3) + delay.trans_instrs = imm[i] - (uint32_t)alu_delay_wait::TRANS32_DEP_1 + 1; + else if (wait >= alu_delay_wait::SALU_CYCLE_1) + delay.salu_cycles = imm[i] - (uint32_t)alu_delay_wait::SALU_CYCLE_1 + 1; + } + return true; +} + void perform_barrier(wait_ctx& ctx, wait_imm& imm, memory_sync_info sync, unsigned semantics) { @@ -359,7 +480,28 @@ force_waitcnt(wait_ctx& ctx, wait_imm& imm) } void -kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_info) +update_alu(wait_ctx& ctx, bool is_valu, bool is_trans, bool clear, int cycles) +{ + for (std::pair<const PhysReg, wait_entry>& e : ctx.gpr_map) { + wait_entry& entry = e.second; + + if (clear) { + entry.delay = alu_delay_info(); + } 
else { + entry.delay.valu_instrs += is_valu ? 1 : 0; + entry.delay.trans_instrs += is_trans ? 1 : 0; + entry.delay.salu_cycles -= cycles; + entry.delay.valu_cycles -= cycles; + entry.delay.trans_cycles -= cycles; + + entry.delay.fixup(); + } + } +} + +void +kill(wait_imm& imm, alu_delay_info& delay, Instruction* instr, wait_ctx& ctx, + memory_sync_info sync_info) { if (instr->opcode == aco_opcode::s_setpc_b64 || (debug_flags & DEBUG_FORCE_WAITCNT)) { /* Force emitting waitcnt states right after the instruction if there is @@ -369,8 +511,7 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf force_waitcnt(ctx, imm); } - if (ctx.exp_cnt || ctx.vm_cnt || ctx.lgkm_cnt) - check_instr(ctx, imm, instr); + check_instr(ctx, imm, delay, instr); /* It's required to wait for scalar stores before "writing back" data. * It shouldn't cost anything anyways since we're about to do s_endpgm. @@ -418,7 +559,7 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf else perform_barrier(ctx, imm, sync_info, semantic_release); - if (!imm.empty()) { + if (!imm.empty() || !delay.empty()) { if (ctx.pending_flat_vm && imm.vm != wait_imm::unset_counter) imm.vm = 0; if (ctx.pending_flat_lgkm && imm.lgkm != wait_imm::unset_counter) @@ -454,6 +595,10 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf bar_ev &= ~event_flat; } + if (ctx.program->gfx_level >= GFX11) { + update_alu(ctx, false, false, false, MAX3(delay.salu_cycles, delay.valu_cycles, delay.trans_cycles)); + } + /* remove all gprs with higher counter from map */ std::map<PhysReg, wait_entry>::iterator it = ctx.gpr_map.begin(); while (it != ctx.gpr_map.end()) { @@ -465,6 +610,13 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf ctx.wait_and_remove_from_entry(it->first, it->second, counter_lgkm); if (imm.vs != wait_imm::unset_counter && imm.vs <= it->second.imm.vs) ctx.wait_and_remove_from_entry(it->first, it->second, counter_vs); + 
if (delay.valu_instrs <= it->second.delay.valu_instrs) + it->second.delay.valu_instrs = alu_delay_info::valu_nop; + if (delay.trans_instrs <= it->second.delay.trans_instrs) + it->second.delay.trans_instrs = alu_delay_info::trans_nop; + it->second.delay.fixup(); + if (it->second.delay.empty()) + ctx.wait_and_remove_from_entry(it->first, it->second, counter_alu); if (!it->second.counters) it = ctx.gpr_map.erase(it); else @@ -587,7 +739,7 @@ update_counters_for_flat_load(wait_ctx& ctx, memory_sync_info sync = memory_sync void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, bool wait_on_read, - uint8_t vmem_types = 0) + uint8_t vmem_types = 0, unsigned cycles = 0) { uint16_t counters = get_counters_for_event(event); wait_imm imm; @@ -600,7 +752,18 @@ insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, boo if (counters & counter_vs) imm.vs = 0; - wait_entry new_entry(event, imm, !rc.is_linear(), wait_on_read); + alu_delay_info delay; + if (event == event_valu) { + delay.valu_instrs = 0; + delay.valu_cycles = cycles; + } else if (event == event_trans) { + delay.trans_instrs = 0; + delay.trans_cycles = cycles; + } else if (event == event_salu) { + delay.salu_cycles = cycles; + } + + wait_entry new_entry(event, imm, delay, !rc.is_linear(), wait_on_read); new_entry.vmem_types |= vmem_types; for (unsigned i = 0; i < rc.size(); i++) { @@ -614,13 +777,38 @@ void insert_wait_entry(wait_ctx& ctx, Operand op, wait_event event, uint8_t vmem_types = 0) { if (!op.isConstant() && !op.isUndefined()) - insert_wait_entry(ctx, op.physReg(), op.regClass(), event, false, vmem_types); + insert_wait_entry(ctx, op.physReg(), op.regClass(), event, false, vmem_types, 0); } void -insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event, uint8_t vmem_types = 0) +insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event, uint8_t vmem_types = 0, + unsigned cycles = 0) { - insert_wait_entry(ctx, def.physReg(), def.regClass(), 
event, true, vmem_types); + insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true, vmem_types, cycles); +} + +void +gen_alu(Instruction* instr, wait_ctx& ctx) +{ + Instruction_cycle_info cycle_info = get_cycle_info(*ctx.program, *instr); + bool is_valu = instr->isVALU() || instr->isVINTERP_INREG(); + bool is_trans = instr->isTrans(); + bool clear = instr->isEXP() || instr->isDS() || instr->isMIMG() || instr->isFlatLike() || + instr->isMUBUF() || instr->isMTBUF(); + + wait_event event = (wait_event)0; + if (is_trans) + event = event_trans; + else if (is_valu) + event = event_valu; + else if (instr->isSALU()) + event = event_salu; + + if (event != (wait_event)0) { + for (const Definition& def : instr->definitions) + insert_wait_entry(ctx, def, event, 0, cycle_info.latency); + } + update_alu(ctx, is_valu, is_trans, clear, cycle_info.issue_cycles); } void @@ -755,22 +943,55 @@ emit_waitcnt(wait_ctx& ctx, std::vector<aco_ptr<Instruction>>& instructions, wai imm = wait_imm(); } +void +emit_delay_alu(wait_ctx& ctx, std::vector<aco_ptr<Instruction>>& instructions, + alu_delay_info& delay) +{ + uint32_t imm = 0; + if (delay.trans_instrs != delay.trans_nop) { + imm |= (uint32_t)alu_delay_wait::TRANS32_DEP_1 + delay.trans_instrs - 1; + } + + if (delay.valu_instrs != delay.valu_nop) { + imm |= ((uint32_t)alu_delay_wait::VALU_DEP_1 + delay.valu_instrs - 1) << (imm ? 7 : 0); + } + + /* Note that we can only put 2 wait conditions in the instruction, so if we have all 3 we just + * drop the SALU one. Here we use that this doesn't really affect correctness so occasionally + * getting this wrong isn't an issue. */ + if (delay.salu_cycles && imm <= 0xf) { + unsigned cycles = std::min<uint8_t>(3, delay.salu_cycles); + imm |= ((uint32_t)alu_delay_wait::SALU_CYCLE_1 + cycles - 1) << (imm ? 
7 : 0); + } + + SOPP_instruction* inst = + create_instruction<SOPP_instruction>(aco_opcode::s_delay_alu, Format::SOPP, 0, 0); + inst->imm = imm; + inst->block = -1; + instructions.emplace_back(inst); + delay = alu_delay_info(); +} + void handle_block(Program* program, Block& block, wait_ctx& ctx) { std::vector<aco_ptr<Instruction>> new_instructions; wait_imm queued_imm; + alu_delay_info queued_delay; for (aco_ptr<Instruction>& instr : block.instructions) { bool is_wait = parse_wait_instr(ctx, queued_imm, instr.get()); + bool is_delay_alu = parse_delay_alu(ctx, queued_delay, instr.get()); memory_sync_info sync_info = get_sync_info(instr.get()); - kill(queued_imm, instr.get(), ctx, sync_info); + kill(queued_imm, queued_delay, instr.get(), ctx, sync_info); gen(instr.get(), ctx); + if (program->gfx_level >= GFX11) + gen_alu(instr.get(), ctx); - if (instr->format != Format::PSEUDO_BARRIER && !is_wait) { + if (instr->format != Format::PSEUDO_BARRIER && !is_wait && !is_delay_alu) { if (instr->isVINTERP_INREG() && queued_imm.exp != wait_imm::unset_counter) { instr->vinterp_inreg().wait_exp = MIN2(instr->vinterp_inreg().wait_exp, queued_imm.exp); queued_imm.exp = wait_imm::unset_counter; @@ -778,6 +999,8 @@ handle_block(Program* program, Block& block, wait_ctx& ctx) if (!queued_imm.empty()) emit_waitcnt(ctx, new_instructions, queued_imm); + if (!queued_delay.empty()) + emit_delay_alu(ctx, new_instructions, queued_delay); bool is_ordered_count_acquire = instr->opcode == aco_opcode::ds_ordered_count && @@ -793,6 +1016,8 @@ handle_block(Program* program, Block& block, wait_ctx& ctx) if (!queued_imm.empty()) emit_waitcnt(ctx, new_instructions, queued_imm); + if (!queued_delay.empty()) + emit_delay_alu(ctx, new_instructions, queued_delay); block.instructions.swap(new_instructions); }