From 0ffbc30d7fa17d51281fb53ba65900f379b601b3 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Thu, 15 Jan 2026 15:25:55 +0000 Subject: [PATCH] aco: refactor spiller to use spills_needed variable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rhys Perry Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_ir.h | 12 ++++++ src/amd/compiler/aco_spill.cpp | 67 ++++++++++++++-------------------- 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h index e9b60834873..86078db7765 100644 --- a/src/amd/compiler/aco_ir.h +++ b/src/amd/compiler/aco_ir.h @@ -1067,6 +1067,18 @@ struct RegisterDemand { return vgpr > other.vgpr || sgpr > other.sgpr; } + constexpr bool empty() const noexcept { return !exceeds(RegisterDemand()); } + + constexpr const int16_t& operator[](RegType type) const noexcept + { + return type == RegType::vgpr ? vgpr : sgpr; + } + + constexpr int16_t& operator[](RegType type) noexcept + { + return type == RegType::vgpr ? vgpr : sgpr; + } + constexpr RegisterDemand operator+(const Temp t) const noexcept { if (t.type() == RegType::sgpr) diff --git a/src/amd/compiler/aco_spill.cpp b/src/amd/compiler/aco_spill.cpp index bda6e18476a..6b57a47e235 100644 --- a/src/amd/compiler/aco_spill.cpp +++ b/src/amd/compiler/aco_spill.cpp @@ -208,6 +208,12 @@ gather_ssa_use_info(spill_ctx& ctx) } } +RegType +get_spill_regtype(RegisterDemand demand) +{ + return demand.vgpr > 0 ? RegType::vgpr : RegType::sgpr; +} + bool should_rematerialize(aco_ptr& instr) { @@ -345,7 +351,6 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) /* keep live-through variables spilled */ ctx.spills_entry[block_idx][spilled.first] = spilled.second; spilled_registers += spilled.first; - loop_demand -= spilled.first; } if (!ctx.loop.empty()) { /* If this is a nested loop, keep variables from the outer loop spilled. */ @@ -356,24 +361,18 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) if (live_in.count(spilled.first.id()) && ctx.spills_entry[block_idx].insert(spilled).second) { spilled_registers += spilled.first; - loop_demand -= spilled.first; } } } - /* select more live-through variables and constants */ - RegType type = RegType::vgpr; - while (loop_demand.exceeds(ctx.target_pressure) || - loop_call_spills.exceeds(spilled_registers)) { - /* if VGPR demand is low enough, select SGPRs */ - if (type == RegType::vgpr && loop_demand.vgpr <= ctx.target_pressure.vgpr && - loop_call_spills.vgpr <= spilled_registers.vgpr) - type = RegType::sgpr; - /* if SGPR demand is low enough, break */ - if (type == RegType::sgpr && loop_demand.sgpr <= ctx.target_pressure.sgpr && - loop_call_spills.sgpr <= spilled_registers.sgpr) - break; + loop_demand -= spilled_registers; + loop_call_spills -= spilled_registers; + /* select more live-through variables and constants */ + RegisterDemand spills_needed = loop_demand - ctx.target_pressure; + spills_needed.update(loop_call_spills); + while (!spills_needed.empty()) { + RegType type = get_spill_regtype(spills_needed); float score = 0.0; unsigned remat = 0; Temp to_spill; @@ -391,17 +390,15 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) } } - /* select SGPRs or break */ if (score == 0.0) { - if (type == RegType::sgpr) - break; - type = RegType::sgpr; + spills_needed[type] = 0; continue; } ctx.add_to_spills(to_spill, ctx.spills_entry[block_idx]); spilled_registers += to_spill; loop_demand -= to_spill; + spills_needed -= to_spill; } /* create new loop_info */ @@ -415,10 +412,11 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) /* if reg pressure is too high at beginning of loop, add variables with furthest use */ reg_pressure -= spilled_registers; - while (reg_pressure.exceeds(ctx.target_pressure)) { + spills_needed = reg_pressure - ctx.target_pressure; + while (!spills_needed.empty()) { float score = 0; Temp to_spill = Temp(); - type = reg_pressure.vgpr > ctx.target_pressure.vgpr ? RegType::vgpr : RegType::sgpr; + RegType type = get_spill_regtype(spills_needed); for (aco_ptr& phi : block->instructions) { if (!is_phi(phi)) break; @@ -434,7 +432,7 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) assert(to_spill != Temp()); ctx.add_to_spills(to_spill, ctx.spills_entry[block_idx]); spilled_registers += to_spill; - reg_pressure -= to_spill; + spills_needed -= to_spill; } return spilled_registers; @@ -550,13 +548,14 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) RegisterDemand reg_pressure = block->live_in_demand; reg_pressure -= spilled_registers; - while (reg_pressure.exceeds(ctx.target_pressure)) { + RegisterDemand spills_needed = reg_pressure - ctx.target_pressure; + while (!spills_needed.empty()) { assert(!partial_spills.empty()); std::map::iterator it = partial_spills.begin(); Temp to_spill = Temp(); bool is_partial_spill = false; float score = 0.0; - RegType type = reg_pressure.vgpr > ctx.target_pressure.vgpr ? RegType::vgpr : RegType::sgpr; + RegType type = get_spill_regtype(spills_needed); while (it != partial_spills.end()) { assert(!ctx.spills_entry[block_idx].count(it->first)); @@ -574,7 +573,7 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx) ctx.add_to_spills(to_spill, ctx.spills_entry[block_idx]); partial_spills.erase(to_spill); spilled_registers += to_spill; - reg_pressure -= to_spill; + spills_needed -= to_spill; } return spilled_registers; @@ -970,7 +969,7 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s /* if reg pressure is too high, spill variable with furthest next use */ while (true) { - bool needs_spill = (new_demand - spilled_registers).exceeds(ctx.target_pressure); + RegisterDemand spills_needed = new_demand - ctx.target_pressure; if (instr->isCall()) { RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit; @@ -980,10 +979,10 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s call_preserved_limit.vgpr = MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0); - needs_spill |= (instr->call().caller_preserved_demand - spilled_registers) - .exceeds(call_preserved_limit); + spills_needed.update(instr->call().caller_preserved_demand - call_preserved_limit); } - if (!needs_spill) + spills_needed -= spilled_registers; + if (spills_needed.empty()) break; float score = 0.0; @@ -994,17 +993,7 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s unsigned do_rematerialize = 0; unsigned avoid_respill = 0; - RegType type = RegType::sgpr; - bool spill_vgpr = new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr; - if (instr->isCall()) { - RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit; - call_preserved_limit.vgpr = - MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0); - spill_vgpr |= instr->call().caller_preserved_demand.vgpr - spilled_registers.vgpr > - call_preserved_limit.vgpr; - } - if (spill_vgpr) - type = RegType::vgpr; + RegType type = get_spill_regtype(spills_needed); for (unsigned t : ctx.program->live.live_in[block_idx]) { RegClass rc = ctx.program->temp_rc[t];