aco: refactor spiller to use spills_needed variable

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39690>
This commit is contained in:
Rhys Perry 2026-01-15 15:25:55 +00:00 committed by Marge Bot
parent e60b49a3f6
commit 0ffbc30d7f
2 changed files with 40 additions and 39 deletions

View file

@ -1067,6 +1067,18 @@ struct RegisterDemand {
return vgpr > other.vgpr || sgpr > other.sgpr;
}
constexpr bool empty() const noexcept { return !exceeds(RegisterDemand()); }
constexpr const int16_t& operator[](RegType type) const noexcept
{
return type == RegType::vgpr ? vgpr : sgpr;
}
constexpr int16_t& operator[](RegType type) noexcept
{
return type == RegType::vgpr ? vgpr : sgpr;
}
constexpr RegisterDemand operator+(const Temp t) const noexcept
{
if (t.type() == RegType::sgpr)

View file

@ -208,6 +208,12 @@ gather_ssa_use_info(spill_ctx& ctx)
}
}
RegType
get_spill_regtype(RegisterDemand demand)
{
return demand.vgpr > 0 ? RegType::vgpr : RegType::sgpr;
}
bool
should_rematerialize(aco_ptr<Instruction>& instr)
{
@ -345,7 +351,6 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
/* keep live-through variables spilled */
ctx.spills_entry[block_idx][spilled.first] = spilled.second;
spilled_registers += spilled.first;
loop_demand -= spilled.first;
}
if (!ctx.loop.empty()) {
/* If this is a nested loop, keep variables from the outer loop spilled. */
@ -356,24 +361,18 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
if (live_in.count(spilled.first.id()) &&
ctx.spills_entry[block_idx].insert(spilled).second) {
spilled_registers += spilled.first;
loop_demand -= spilled.first;
}
}
}
/* select more live-through variables and constants */
RegType type = RegType::vgpr;
while (loop_demand.exceeds(ctx.target_pressure) ||
loop_call_spills.exceeds(spilled_registers)) {
/* if VGPR demand is low enough, select SGPRs */
if (type == RegType::vgpr && loop_demand.vgpr <= ctx.target_pressure.vgpr &&
loop_call_spills.vgpr <= spilled_registers.vgpr)
type = RegType::sgpr;
/* if SGPR demand is low enough, break */
if (type == RegType::sgpr && loop_demand.sgpr <= ctx.target_pressure.sgpr &&
loop_call_spills.sgpr <= spilled_registers.sgpr)
break;
loop_demand -= spilled_registers;
loop_call_spills -= spilled_registers;
/* select more live-through variables and constants */
RegisterDemand spills_needed = loop_demand - ctx.target_pressure;
spills_needed.update(loop_call_spills);
while (!spills_needed.empty()) {
RegType type = get_spill_regtype(spills_needed);
float score = 0.0;
unsigned remat = 0;
Temp to_spill;
@ -391,17 +390,15 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
}
}
/* select SGPRs or break */
if (score == 0.0) {
if (type == RegType::sgpr)
break;
type = RegType::sgpr;
spills_needed[type] = 0;
continue;
}
ctx.add_to_spills(to_spill, ctx.spills_entry[block_idx]);
spilled_registers += to_spill;
loop_demand -= to_spill;
spills_needed -= to_spill;
}
/* create new loop_info */
@ -415,10 +412,11 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
/* if reg pressure is too high at beginning of loop, add variables with furthest use */
reg_pressure -= spilled_registers;
while (reg_pressure.exceeds(ctx.target_pressure)) {
spills_needed = reg_pressure - ctx.target_pressure;
while (!spills_needed.empty()) {
float score = 0;
Temp to_spill = Temp();
type = reg_pressure.vgpr > ctx.target_pressure.vgpr ? RegType::vgpr : RegType::sgpr;
RegType type = get_spill_regtype(spills_needed);
for (aco_ptr<Instruction>& phi : block->instructions) {
if (!is_phi(phi))
break;
@ -434,7 +432,7 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
assert(to_spill != Temp());
ctx.add_to_spills(to_spill, ctx.spills_entry[block_idx]);
spilled_registers += to_spill;
reg_pressure -= to_spill;
spills_needed -= to_spill;
}
return spilled_registers;
@ -550,13 +548,14 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
RegisterDemand reg_pressure = block->live_in_demand;
reg_pressure -= spilled_registers;
while (reg_pressure.exceeds(ctx.target_pressure)) {
RegisterDemand spills_needed = reg_pressure - ctx.target_pressure;
while (!spills_needed.empty()) {
assert(!partial_spills.empty());
std::map<Temp, bool>::iterator it = partial_spills.begin();
Temp to_spill = Temp();
bool is_partial_spill = false;
float score = 0.0;
RegType type = reg_pressure.vgpr > ctx.target_pressure.vgpr ? RegType::vgpr : RegType::sgpr;
RegType type = get_spill_regtype(spills_needed);
while (it != partial_spills.end()) {
assert(!ctx.spills_entry[block_idx].count(it->first));
@ -574,7 +573,7 @@ init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_idx)
ctx.add_to_spills(to_spill, ctx.spills_entry[block_idx]);
partial_spills.erase(to_spill);
spilled_registers += to_spill;
reg_pressure -= to_spill;
spills_needed -= to_spill;
}
return spilled_registers;
@ -970,7 +969,7 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
/* if reg pressure is too high, spill variable with furthest next use */
while (true) {
bool needs_spill = (new_demand - spilled_registers).exceeds(ctx.target_pressure);
RegisterDemand spills_needed = new_demand - ctx.target_pressure;
if (instr->isCall()) {
RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit;
@ -980,10 +979,10 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
call_preserved_limit.vgpr =
MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0);
needs_spill |= (instr->call().caller_preserved_demand - spilled_registers)
.exceeds(call_preserved_limit);
spills_needed.update(instr->call().caller_preserved_demand - call_preserved_limit);
}
if (!needs_spill)
spills_needed -= spilled_registers;
if (spills_needed.empty())
break;
float score = 0.0;
@ -994,17 +993,7 @@ process_block(spill_ctx& ctx, unsigned block_idx, Block* block, RegisterDemand s
unsigned do_rematerialize = 0;
unsigned avoid_respill = 0;
RegType type = RegType::sgpr;
bool spill_vgpr = new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr;
if (instr->isCall()) {
RegisterDemand call_preserved_limit = instr->call().callee_preserved_limit;
call_preserved_limit.vgpr =
MAX2(call_preserved_limit.vgpr - (int16_t)ctx.extra_vgprs, 0);
spill_vgpr |= instr->call().caller_preserved_demand.vgpr - spilled_registers.vgpr >
call_preserved_limit.vgpr;
}
if (spill_vgpr)
type = RegType::vgpr;
RegType type = get_spill_regtype(spills_needed);
for (unsigned t : ctx.program->live.live_in[block_idx]) {
RegClass rc = ctx.program->temp_rc[t];