From 6d799ac283de1dc54d86564b11aee541d0f278e3 Mon Sep 17 00:00:00 2001 From: Natalie Vock Date: Sat, 6 Sep 2025 14:45:28 +0200 Subject: [PATCH] aco: Add pass for spilling call-related registers This is a post-RA pass that tracks registers that are preserved by the ABI, but clobbered by shader code. The pass inserts scratch spills and reloads in appropriate locations to ensure the register values at the end of the shader are the same as they were at the start. Part-of: --- src/amd/compiler/aco_builder_h.py | 2 +- src/amd/compiler/aco_interface.cpp | 2 + src/amd/compiler/aco_ir.h | 7 + src/amd/compiler/aco_spill_preserved.cpp | 671 +++++++++++++++++++++++ src/amd/compiler/meson.build | 1 + 5 files changed, 682 insertions(+), 1 deletion(-) create mode 100644 src/amd/compiler/aco_spill_preserved.cpp diff --git a/src/amd/compiler/aco_builder_h.py b/src/amd/compiler/aco_builder_h.py index dd2e1f80c2c..c0a01dddecf 100644 --- a/src/amd/compiler/aco_builder_h.py +++ b/src/amd/compiler/aco_builder_h.py @@ -563,7 +563,7 @@ public: <% import itertools formats = [("pseudo", [Format.PSEUDO], list(itertools.product(range(5), range(7))) + [(8, 1), (1, 8), (1, 7)]), - ("sop1", [Format.SOP1], [(0, 1), (1, 0), (1, 1), (2, 1), (3, 2)]), + ("sop1", [Format.SOP1], [(0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (3, 2)]), ("sop2", [Format.SOP2], itertools.product([1, 2], [2, 3])), ("sopk", [Format.SOPK], itertools.product([0, 1, 2], [0, 1])), ("sopp", [Format.SOPP], [(0, 0), (0, 1)]), diff --git a/src/amd/compiler/aco_interface.cpp b/src/amd/compiler/aco_interface.cpp index f7cb834acf2..3e71d3bae0e 100644 --- a/src/amd/compiler/aco_interface.cpp +++ b/src/amd/compiler/aco_interface.cpp @@ -144,6 +144,8 @@ aco_postprocess_shader(const struct aco_compiler_options* options, validate(program.get()); } + spill_preserved(program.get()); + /* Lower to HW Instructions */ ssa_elimination(program.get()); lower_to_hw_instr(program.get()); diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h index 4434e74249c..c10e82c6592 100644 --- a/src/amd/compiler/aco_ir.h +++ b/src/amd/compiler/aco_ir.h @@ -2431,6 +2431,7 @@ void setup_reduce_temp(Program* program); void lower_to_cssa(Program* program); void register_allocation(Program* program, ra_test_policy = {}); void reindex_ssa(Program* program); +void spill_preserved(Program* program); void ssa_elimination(Program* program); void lower_to_hw_instr(Program* program); void schedule_program(Program* program); @@ -2593,4 +2594,10 @@ extern const Info instr_info; } // namespace aco +namespace std { +template <> struct hash { + size_t operator()(aco::PhysReg reg) const noexcept { return std::hash{}(reg.reg_b); } +}; +} // namespace std + #endif /* ACO_IR_H */ diff --git a/src/amd/compiler/aco_spill_preserved.cpp b/src/amd/compiler/aco_spill_preserved.cpp new file mode 100644 index 00000000000..77421e77933 --- /dev/null +++ b/src/amd/compiler/aco_spill_preserved.cpp @@ -0,0 +1,671 @@ +/* + * Copyright © 2024 Valve Corporation + * + * SPDX-License-Identifier: MIT + */ + +#include "aco_builder.h" +#include "aco_ir.h" + +#include +#include +#include + +namespace aco { + +struct postdom_info { + unsigned logical_imm_postdom; + unsigned linear_imm_postdom; +}; + +struct spill_preserved_ctx { + Program* program; + BITSET_DECLARE(abi_preserved_regs, 512); + aco::monotonic_buffer_resource memory; + + /* Info on how to spill preserved VGPRs. */ + aco::unordered_map preserved_spill_offsets; + aco::unordered_set preserved_vgprs; + aco::unordered_set preserved_linear_vgprs; + /* Info on how to spill preserved SGPRs. */ + aco::unordered_map preserved_spill_lanes; + aco::unordered_set preserved_sgprs; + + aco::unordered_map> reg_block_uses; + std::vector dom_info; + + /* The start of the register range dedicated to spilling preserved SGPRs. */ + aco::unordered_set sgpr_spill_regs; + + /* Next scratch offset to spill VGPRs to. */ + unsigned next_preserved_offset; + /* Next linear VGPR lane to spill SGPRs to. */ + unsigned next_preserved_lane; + + explicit spill_preserved_ctx(Program* program_) + : program(program_), memory(), preserved_spill_offsets(memory), preserved_vgprs(memory), + preserved_linear_vgprs(memory), preserved_spill_lanes(memory), preserved_sgprs(memory), + reg_block_uses(memory), sgpr_spill_regs(memory), + next_preserved_offset( + DIV_ROUND_UP(program_->config->scratch_bytes_per_wave, program_->wave_size)), + next_preserved_lane(0) + { + program->callee_abi.preservedRegisters(abi_preserved_regs); + dom_info.resize(program->blocks.size(), {-1u, -1u}); + } +}; + +bool +can_reload_at_instr(const aco_ptr& instr) +{ + return instr->opcode == aco_opcode::p_reload_preserved || instr->opcode == aco_opcode::p_return; +} + +void +add_instr(spill_preserved_ctx& ctx, unsigned block_index, bool seen_reload, + const aco_ptr& instr, Instruction* startpgm) +{ + for (auto& def : instr->definitions) { + assert(def.isFixed()); + /* Round down subdword registers to their base */ + PhysReg start_reg = PhysReg{def.physReg().reg()}; + for (PhysReg reg = start_reg; reg < start_reg.advance(def.bytes()); reg = reg.advance(4)) { + if (!BITSET_TEST(ctx.abi_preserved_regs, reg.reg()) && !def.regClass().is_linear_vgpr()) + continue; + + if (instr->opcode == aco_opcode::p_start_linear_vgpr) { + /* Don't count start_linear_vgpr without a copy as a use since the value doesn't matter. + * This allows us to move reloads a bit further up the CF. + */ + if (instr->operands.empty()) + continue; + } + + if (def.regClass().is_linear_vgpr()) + ctx.preserved_linear_vgprs.insert(reg); + else if (def.regClass().type() == RegType::sgpr) + ctx.preserved_sgprs.insert(reg); + else + ctx.preserved_vgprs.insert(reg); + + if (seen_reload) { + if (def.regClass().is_linear()) + for (auto succ : ctx.program->blocks[block_index].linear_succs) + ctx.reg_block_uses[reg].emplace(succ); + else + for (auto succ : ctx.program->blocks[block_index].logical_succs) + ctx.reg_block_uses[reg].emplace(succ); + } else { + ctx.reg_block_uses[reg].emplace(block_index); + } + } + } + + for (auto& op : instr->operands) { + assert(op.isFixed()); + + if (!op.isTemp()) + continue; + /* Temporaries defined by startpgm are the preserved value - these uses don't need + * any preservation. + */ + if (std::any_of(startpgm->definitions.begin(), startpgm->definitions.end(), + [op](const auto& def) + { return def.isTemp() && def.tempId() == op.tempId(); })) + continue; + + /* Round down subdword registers to their base */ + PhysReg start_reg = PhysReg{op.physReg().reg()}; + for (PhysReg reg = start_reg; reg < start_reg.advance(op.bytes()); reg = reg.advance(4)) { + if (instr->opcode == aco_opcode::p_spill && &op == &instr->operands[0]) { + assert(op.regClass().is_linear_vgpr()); + ctx.preserved_linear_vgprs.insert(reg); + } + + if (seen_reload) { + if (op.regClass().is_linear()) + for (auto succ : ctx.program->blocks[block_index].linear_succs) + ctx.reg_block_uses[reg].emplace(succ); + else + for (auto succ : ctx.program->blocks[block_index].logical_succs) + ctx.reg_block_uses[reg].emplace(succ); + } else { + ctx.reg_block_uses[reg].emplace(block_index); + } + } + } +} + +void +add_preserved_vgpr_spill(spill_preserved_ctx& ctx, PhysReg reg, + std::vector>& spills) +{ + assert(ctx.preserved_spill_offsets.find(reg) == ctx.preserved_spill_offsets.end()); + unsigned offset = ctx.next_preserved_offset; + ctx.next_preserved_offset += 4; + ctx.preserved_spill_offsets.emplace(reg, offset); + + spills.emplace_back(reg, offset); +} + +void +add_preserved_sgpr_spill(spill_preserved_ctx& ctx, PhysReg reg, + std::vector>& spills) +{ + unsigned lane; + + assert(ctx.preserved_spill_lanes.find(reg) == ctx.preserved_spill_lanes.end()); + lane = ctx.next_preserved_lane++; + ctx.preserved_spill_lanes.emplace(reg, lane); + + spills.emplace_back(reg, lane); + + unsigned vgpr_idx = lane / ctx.program->wave_size; + for (auto& spill_reg : ctx.sgpr_spill_regs) { + for (auto use : ctx.reg_block_uses[reg]) + ctx.reg_block_uses[spill_reg.advance(vgpr_idx * 4)].insert(use); + } +} + +void +emit_vgpr_spills_reloads(spill_preserved_ctx& ctx, Builder& bld, + std::vector>& spills, PhysReg stack_reg, + bool reload, bool linear) +{ + if (spills.empty()) + return; + + unsigned first_spill_offset = + DIV_ROUND_UP(ctx.program->config->scratch_bytes_per_wave, ctx.program->wave_size); + + int end_offset = (int)spills.back().second; + bool overflow = end_offset >= ctx.program->dev.scratch_global_offset_max; + if (overflow) { + for (auto& spill : spills) + spill.second -= first_spill_offset; + + if (ctx.program->gfx_level < GFX9) + first_spill_offset *= ctx.program->wave_size; + + bld.sop2(aco_opcode::s_addc_u32, Definition(stack_reg, s1), Definition(scc, s1), + Operand(stack_reg, s1), Operand::c32(first_spill_offset), Operand(scc, s1)); + if (ctx.program->gfx_level < GFX9) + bld.sop2(aco_opcode::s_addc_u32, Definition(stack_reg.advance(4), s1), Definition(scc, s1), + Operand(stack_reg.advance(4), s1), Operand::c32(0), Operand(scc, s1)); + bld.sopc(aco_opcode::s_bitcmp1_b32, Definition(scc, s1), Operand(stack_reg, s1), + Operand::c32(0)); + bld.sop1(aco_opcode::s_bitset0_b32, Definition(stack_reg, s1), Operand::c32(0), Operand(stack_reg, s1)); + } + + for (const auto& spill : spills) { + if (ctx.program->gfx_level >= GFX9) { + if (reload) + bld.scratch(aco_opcode::scratch_load_dword, + Definition(spill.first, linear ? v1.as_linear() : v1), Operand(v1), + Operand(stack_reg, s1), spill.second, + memory_sync_info(storage_vgpr_spill, semantic_private)); + else + bld.scratch(aco_opcode::scratch_store_dword, Operand(v1), Operand(stack_reg, s1), + Operand(spill.first, linear ? v1.as_linear() : v1), + spill.second, + memory_sync_info(storage_vgpr_spill, semantic_private)); + } else { + if (reload) { + Instruction* instr = bld.mubuf( + aco_opcode::buffer_load_dword, Definition(spill.first, linear ? v1.as_linear() : v1), + Operand(stack_reg, s4), Operand(v1), Operand::c32(0), spill.second, false); + instr->mubuf().sync = memory_sync_info(storage_vgpr_spill, semantic_private); + instr->mubuf().cache.value = ac_swizzled; + } else { + Instruction* instr = + bld.mubuf(aco_opcode::buffer_store_dword, Operand(stack_reg, s4), Operand(v1), + Operand::c32(0), Operand(spill.first, linear ? v1.as_linear() : v1), + spill.second, false); + instr->mubuf().sync = memory_sync_info(storage_vgpr_spill, semantic_private); + instr->mubuf().cache.value = ac_swizzled; + } + } + } + + if (overflow) { + bld.sop2(aco_opcode::s_addc_u32, Definition(stack_reg, s1), Definition(scc, s1), + Operand(stack_reg, s1), Operand::c32(-first_spill_offset), Operand(scc, s1)); + if (ctx.program->gfx_level < GFX9) + bld.sop2(aco_opcode::s_subb_u32, Definition(stack_reg.advance(4), s1), Definition(scc, s1), + Operand(stack_reg.advance(4), s1), Operand::c32(0), Operand(scc, s1)); + bld.sopc(aco_opcode::s_bitcmp1_b32, Definition(scc, s1), Operand(stack_reg, s1), + Operand::c32(0)); + bld.sop1(aco_opcode::s_bitset0_b32, Definition(stack_reg, s1), Operand::c32(0), + Operand(stack_reg, s1)); + } +} + +void +emit_sgpr_spills_reloads(spill_preserved_ctx& ctx, std::vector>& instructions, + std::vector>::iterator& insert_point, + PhysReg spill_reg, std::vector>& spills, + bool reload) +{ + std::vector> spill_instructions; + Builder bld(ctx.program, &spill_instructions); + + for (auto& spill : spills) { + unsigned vgpr_idx = spill.second / ctx.program->wave_size; + unsigned lane = spill.second % ctx.program->wave_size; + Operand vgpr_op = Operand(spill_reg.advance(vgpr_idx * 4), v1.as_linear()); + if (reload) + bld.pseudo(aco_opcode::p_reload, bld.def(s1, spill.first), vgpr_op, Operand::c32(lane)); + else + bld.pseudo(aco_opcode::p_spill, vgpr_op, Operand::c32(lane), Operand(spill.first, s1)); + } + + insert_point = instructions.insert(insert_point, std::move_iterator(spill_instructions.begin()), + std::move_iterator(spill_instructions.end())); +} + +void +emit_spills_reloads(spill_preserved_ctx& ctx, std::vector>& instructions, + std::vector>::iterator& insert_point, + std::vector>& spills, + std::vector>& lvgpr_spills, bool reload) +{ + auto spill_reload_compare = [](const auto& first, const auto& second) + { return first.second < second.second; }; + + std::sort(spills.begin(), spills.end(), spill_reload_compare); + std::sort(lvgpr_spills.begin(), lvgpr_spills.end(), spill_reload_compare); + + PhysReg stack_reg, exec_backup; + if ((*insert_point)->opcode == aco_opcode::p_startpgm || + (*insert_point)->opcode == aco_opcode::p_return) { + if ((*insert_point)->opcode == aco_opcode::p_startpgm) + stack_reg = (*insert_point)->definitions[0].physReg(); + else + stack_reg = (*insert_point)->operands[1].physReg(); + + /* We need to find an unused register to use for our exec backup. + * At p_startpgm, everything besides ABI-preserved SGPRs and SGPRs in the instruction + * definitions is unused, so we can stash our exec there, so find and use the first + * register pair matching these requirements. + */ + BITSET_DECLARE(unused_sgprs, 256); + + /* First, fill the bitset with all ABI-clobbered SGPRs. */ + memcpy(unused_sgprs, ctx.abi_preserved_regs, sizeof(unused_sgprs)); + BITSET_NOT(unused_sgprs); + + unsigned sgpr_limit = get_addr_regs_from_waves(ctx.program, ctx.program->min_waves).sgpr; + BITSET_CLEAR_RANGE(unused_sgprs, sgpr_limit, 255); + + /* p_startpgm has the used registers in its definitions and has no operands. + * p_return has the used registers in its operands and has no definitions. + */ + for (auto& def : (*insert_point)->definitions) { + if (def.regClass().type() == RegType::sgpr) { + BITSET_CLEAR_RANGE(unused_sgprs, def.physReg().reg(), + def.physReg().advance(def.bytes()) - 1); + } + } + for (auto& op : (*insert_point)->operands) { + if (op.regClass().type() == RegType::sgpr) { + BITSET_CLEAR_RANGE(unused_sgprs, op.physReg().reg(), + op.physReg().advance(op.bytes()) - 1); + } + } + + bool found_reg = false; + unsigned start_reg, end_reg; + BITSET_FOREACH_RANGE(start_reg, end_reg, unused_sgprs, 256) { + if (ctx.program->lane_mask.size() > 1 && (start_reg & 0x1)) + ++start_reg; + + if (start_reg + ctx.program->lane_mask.size() < end_reg) { + found_reg = true; + exec_backup = PhysReg{start_reg}; + break; + } + } + assert(found_reg && "aco/spill_preserved: No free space to store exec mask backup!"); + + unsigned num_sgprs = + get_sgpr_alloc(ctx.program, exec_backup.reg() + ctx.program->lane_mask.size()); + ctx.program->config->num_sgprs = MAX2(ctx.program->config->num_sgprs, num_sgprs); + ctx.program->max_reg_demand.update(RegisterDemand(0, num_sgprs)); + } else { + stack_reg = (*insert_point)->operands[1].physReg(); + exec_backup = (*insert_point)->definitions[0].physReg(); + } + + std::vector> spill_instructions; + Builder bld(ctx.program, &spill_instructions); + + emit_vgpr_spills_reloads(ctx, bld, spills, stack_reg, reload, false); + if (!lvgpr_spills.empty()) { + bld.sop1(Builder::s_or_saveexec, Definition(exec_backup, bld.lm), Definition(scc, s1), + Definition(exec, bld.lm), Operand::c64(UINT64_MAX), Operand(exec, bld.lm)); + emit_vgpr_spills_reloads(ctx, bld, lvgpr_spills, stack_reg, reload, true); + bld.sop1(Builder::WaveSpecificOpcode::s_mov, Definition(exec, bld.lm), + Operand(exec_backup, bld.lm)); + } + + if ((*insert_point)->opcode != aco_opcode::p_startpgm) + insert_point = instructions.erase(insert_point); + else + ++insert_point; + + insert_point = instructions.insert(insert_point, std::move_iterator(spill_instructions.begin()), + std::move_iterator(spill_instructions.end())); +} + +void +init_block_info(spill_preserved_ctx& ctx) +{ + Instruction* startpgm = ctx.program->blocks.front().instructions.front().get(); + + int cur_loop_header = -1; + for (int index = ctx.program->blocks.size() - 1; index >= 0;) { + const Block& block = ctx.program->blocks[index]; + + if (block.linear_succs.empty()) { + ctx.dom_info[index].logical_imm_postdom = block.index; + ctx.dom_info[index].linear_imm_postdom = block.index; + } else { + int new_logical_postdom = -1; + int new_linear_postdom = -1; + for (unsigned succ_idx : block.logical_succs) { + if ((int)ctx.dom_info[succ_idx].logical_imm_postdom == -1) { + assert(cur_loop_header == -1 || (int)succ_idx >= cur_loop_header); + if (cur_loop_header == -1) + cur_loop_header = (int)succ_idx; + continue; + } + + if (new_logical_postdom == -1) { + new_logical_postdom = (int)succ_idx; + continue; + } + + while ((int)succ_idx != new_logical_postdom) { + if ((int)succ_idx < new_logical_postdom) + succ_idx = ctx.dom_info[succ_idx].logical_imm_postdom; + if ((int)succ_idx > new_logical_postdom) + new_logical_postdom = (int)ctx.dom_info[new_logical_postdom].logical_imm_postdom; + } + } + + for (unsigned succ_idx : block.linear_succs) { + if ((int)ctx.dom_info[succ_idx].linear_imm_postdom == -1) { + assert(cur_loop_header == -1 || (int)succ_idx >= cur_loop_header); + if (cur_loop_header == -1) + cur_loop_header = (int)succ_idx; + continue; + } + + if (new_linear_postdom == -1) { + new_linear_postdom = (int)succ_idx; + continue; + } + + while ((int)succ_idx != new_linear_postdom) { + if ((int)succ_idx < new_linear_postdom) + succ_idx = ctx.dom_info[succ_idx].linear_imm_postdom; + if ((int)succ_idx > new_linear_postdom) + new_linear_postdom = (int)ctx.dom_info[new_linear_postdom].linear_imm_postdom; + } + } + + ctx.dom_info[index].logical_imm_postdom = new_logical_postdom; + ctx.dom_info[index].linear_imm_postdom = new_linear_postdom; + } + + bool seen_reload_vgpr = false; + for (auto& instr : block.instructions) { + if (instr->opcode == aco_opcode::p_startpgm && + ctx.program->callee_abi.block_size.preserved_size.sgpr) { + ctx.sgpr_spill_regs.emplace(instr->definitions.back().physReg()); + continue; + } else if (can_reload_at_instr(instr)) { + if (!instr->operands[0].isUndefined()) + ctx.sgpr_spill_regs.emplace(instr->operands[0].physReg()); + seen_reload_vgpr = true; + } + + add_instr(ctx, index, seen_reload_vgpr, instr, startpgm); + } + + /* Process predecessors of loop headers again, since post-dominance information of the header + * was not available the first time + */ + int next_idx = index - 1; + if (index == cur_loop_header) { + assert(block.kind & block_kind_loop_header); + for (auto pred : block.logical_preds) + if (ctx.dom_info[pred].logical_imm_postdom == -1u) + next_idx = std::max(next_idx, (int)pred); + for (auto pred : block.linear_preds) + if (ctx.dom_info[pred].linear_imm_postdom == -1u) + next_idx = std::max(next_idx, (int)pred); + cur_loop_header = -1; + } + index = next_idx; + } + + if (ctx.preserved_sgprs.size()) { + /* Figure out how many VGPRs we'll use to spill preserved SGPRs to. Manually add the linear + * VGPRs used to spill preserved SGPRs to the set of used linear VGPRs, as add_instr might not + * have seen any actual uses of these VGPRs yet. + */ + unsigned linear_vgprs_needed = DIV_ROUND_UP(ctx.preserved_sgprs.size(), ctx.program->wave_size); + + for (auto spill_reg : ctx.sgpr_spill_regs) { + for (unsigned i = 0; i < linear_vgprs_needed; ++i) + ctx.preserved_linear_vgprs.insert(spill_reg.advance(i * 4)); + } + } + /* If a register is used as both a VGPR and a linear VGPR, spill it as a linear VGPR because + * linear VGPR spilling backs up every lane. + */ + for (auto& lvgpr : ctx.preserved_linear_vgprs) + ctx.preserved_vgprs.erase(lvgpr); +} + +void +emit_call_spills(spill_preserved_ctx& ctx) +{ + std::set linear_vgprs; + std::vector> spills; + + unsigned max_scratch_offset = ctx.next_preserved_offset; + + for (auto& block : ctx.program->blocks) { + for (auto it = block.instructions.begin(); it != block.instructions.end();) { + auto& instr = *it; + + if (instr->opcode == aco_opcode::p_call) { + unsigned scratch_offset = ctx.next_preserved_offset; + BITSET_DECLARE(preserved_regs, 512); + instr->call().abi.preservedRegisters(preserved_regs); + for (auto& op : instr->operands) { + if (!op.isTemp() || !op.isPrecolored() || op.isClobbered()) + continue; + for (unsigned i = 0; i < op.size(); ++i) + BITSET_SET(preserved_regs, op.physReg().reg() + i); + } + for (auto& reg : linear_vgprs) { + if (BITSET_TEST(preserved_regs, reg.reg())) + continue; + spills.emplace_back(reg, scratch_offset); + scratch_offset += 4; + } + + max_scratch_offset = std::max(max_scratch_offset, scratch_offset); + + std::vector> spill_instructions; + Builder bld(ctx.program, &spill_instructions); + + PhysReg stack_reg = instr->operands[0].physReg(); + if (ctx.program->gfx_level < GFX9) + scratch_offset *= ctx.program->wave_size; + + emit_vgpr_spills_reloads(ctx, bld, spills, stack_reg, false, true); + + it = block.instructions.insert(it, std::move_iterator(spill_instructions.begin()), + std::move_iterator(spill_instructions.end())); + /* Move the iterator to directly after the call instruction */ + it += spill_instructions.size() + 1; + + spill_instructions.clear(); + + emit_vgpr_spills_reloads(ctx, bld, spills, stack_reg, true, true); + + it = block.instructions.insert(it, std::move_iterator(spill_instructions.begin()), + std::move_iterator(spill_instructions.end())); + + spills.clear(); + continue; + } else if (instr->opcode == aco_opcode::p_start_linear_vgpr) { + linear_vgprs.insert(instr->definitions[0].physReg()); + } else if (instr->opcode == aco_opcode::p_end_linear_vgpr) { + for (auto& op : instr->operands) + linear_vgprs.erase(op.physReg()); + } + ++it; + } + } + + ctx.next_preserved_offset = max_scratch_offset; +} + +void +emit_preserved_spills(spill_preserved_ctx& ctx) +{ + std::vector> spills; + std::vector> lvgpr_spills; + std::vector> sgpr_spills; + + if (ctx.program->callee_abi.block_size.preserved_size.sgpr == 0) + assert(ctx.preserved_sgprs.empty()); + + for (auto reg : ctx.preserved_vgprs) + add_preserved_vgpr_spill(ctx, reg, spills); + for (auto reg : ctx.preserved_linear_vgprs) + add_preserved_vgpr_spill(ctx, reg, lvgpr_spills); + for (auto reg : ctx.preserved_sgprs) + add_preserved_sgpr_spill(ctx, reg, sgpr_spills); + + /* The spiller inserts linear VGPRs for SGPR spilling in p_startpgm. Move past + * that to start spilling preserved SGPRs. + */ + auto startpgm = ctx.program->blocks.front().instructions.begin(); + auto sgpr_spill_reg = (*startpgm)->definitions.back().physReg(); + auto start_instr = std::next(startpgm); + emit_sgpr_spills_reloads(ctx, ctx.program->blocks.front().instructions, start_instr, + sgpr_spill_reg, sgpr_spills, false); + /* Move the iterator back to the p_startpgm. */ + start_instr = ctx.program->blocks.front().instructions.begin(); + + emit_spills_reloads(ctx, ctx.program->blocks.front().instructions, start_instr, spills, + lvgpr_spills, false); + + auto block_reloads = + std::vector>>(ctx.program->blocks.size()); + auto lvgpr_block_reloads = + std::vector>>(ctx.program->blocks.size()); + auto sgpr_block_reloads = + std::vector>>(ctx.program->blocks.size()); + + for (auto it = ctx.reg_block_uses.begin(); it != ctx.reg_block_uses.end();) { + bool is_linear_vgpr = + ctx.preserved_linear_vgprs.find(it->first) != ctx.preserved_linear_vgprs.end(); + bool is_sgpr = ctx.preserved_sgprs.find(it->first) != ctx.preserved_sgprs.end(); + bool is_linear = is_linear_vgpr || is_sgpr; + + if (!is_linear && ctx.preserved_vgprs.find(it->first) == ctx.preserved_vgprs.end()) { + it = ctx.reg_block_uses.erase(it); + continue; + } + + unsigned min_common_postdom = *it->second.begin(); + + for (auto succ_idx : it->second) { + while (succ_idx != min_common_postdom) { + if (min_common_postdom < succ_idx) { + min_common_postdom = is_linear + ? ctx.dom_info[min_common_postdom].linear_imm_postdom + : ctx.dom_info[min_common_postdom].logical_imm_postdom; + } else { + succ_idx = is_linear ? ctx.dom_info[succ_idx].linear_imm_postdom + : ctx.dom_info[succ_idx].logical_imm_postdom; + } + } + } + + while (std::find_if(ctx.program->blocks[min_common_postdom].instructions.rbegin(), + ctx.program->blocks[min_common_postdom].instructions.rend(), + can_reload_at_instr) == + ctx.program->blocks[min_common_postdom].instructions.rend()) + min_common_postdom = is_linear ? ctx.dom_info[min_common_postdom].linear_imm_postdom + : ctx.dom_info[min_common_postdom].logical_imm_postdom; + + if (is_linear_vgpr) { + lvgpr_block_reloads[min_common_postdom].emplace_back( + it->first, ctx.preserved_spill_offsets[it->first]); + } else if (is_sgpr) { + sgpr_block_reloads[min_common_postdom].emplace_back(it->first, + ctx.preserved_spill_lanes[it->first]); + } else { + block_reloads[min_common_postdom].emplace_back(it->first, + ctx.preserved_spill_offsets[it->first]); + } + + it = ctx.reg_block_uses.erase(it); + } + + for (unsigned i = 0; i < ctx.program->blocks.size(); ++i) { + auto instr_it = std::find_if(ctx.program->blocks[i].instructions.rbegin(), + ctx.program->blocks[i].instructions.rend(), can_reload_at_instr); + if (instr_it == ctx.program->blocks[i].instructions.rend()) { + assert(block_reloads[i].empty() && lvgpr_block_reloads[i].empty()); + continue; + } + std::optional spill_reg; + if (!(*instr_it)->operands[0].isUndefined()) + spill_reg = (*instr_it)->operands[0].physReg(); + + /* Insert VGPR spills after reload_preserved_vgpr, then insert SGPR spills before them. */ + auto end_instr = std::prev(instr_it.base()); + + emit_spills_reloads(ctx, ctx.program->blocks[i].instructions, end_instr, block_reloads[i], + lvgpr_block_reloads[i], true); + if (spill_reg) { + emit_sgpr_spills_reloads(ctx, ctx.program->blocks[i].instructions, end_instr, + *spill_reg, sgpr_block_reloads[i], true); + } + } +} + +void +spill_preserved(Program* program) +{ + if (!program->is_callee && !program->has_call) + return; + + spill_preserved_ctx ctx(program); + + bool has_return = + std::find_if(program->blocks.back().instructions.rbegin(), + program->blocks.back().instructions.rend(), [](const auto& instruction) + { return instruction->opcode == aco_opcode::p_return; }) != + program->blocks.back().instructions.rend(); + + if (program->is_callee && has_return) { + init_block_info(ctx); + emit_preserved_spills(ctx); + } + + if (program->has_call) + emit_call_spills(ctx); + + program->config->scratch_bytes_per_wave = ctx.next_preserved_offset * program->wave_size; +} +} // namespace aco diff --git a/src/amd/compiler/meson.build b/src/amd/compiler/meson.build index db32444a079..521d2771ffa 100644 --- a/src/amd/compiler/meson.build +++ b/src/amd/compiler/meson.build @@ -74,6 +74,7 @@ libaco_files = files( 'aco_scheduler.cpp', 'aco_scheduler_ilp.cpp', 'aco_spill.cpp', + 'aco_spill_preserved.cpp', 'aco_ssa_elimination.cpp', 'aco_statistics.cpp', 'aco_util.h',