From 2b20d568e08075e2d1345bed66755ea2a8c10a0a Mon Sep 17 00:00:00 2001 From: Rhys Perry <pendingchaos02@gmail.com> Date: Fri, 21 Nov 2025 11:36:13 +0000 Subject: [PATCH] aco/ra: prefer clobbered registers in callees MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rhys Perry <pendingchaos02@gmail.com> Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: --- src/amd/compiler/aco_register_allocation.cpp | 137 ++++++++++++------- 1 file changed, 87 insertions(+), 50 deletions(-) diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index bd320e62883..18ba8931128 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -159,6 +159,7 @@ struct ra_ctx { std::bitset<512> war_hint; PhysRegIterator rr_sgpr_it; PhysRegIterator rr_vgpr_it; + BITSET_DECLARE(preserved, 512) = {}; uint16_t sgpr_bounds; uint16_t vgpr_bounds; @@ -175,6 +176,9 @@ struct ra_ctx { phi_dummy.reset(create_instruction(aco_opcode::p_linear_phi, Format::PSEUDO, 0, 0)); limit = get_addr_regs_from_waves(program, program->min_waves); + if (program->is_callee) + program->callee_abi.preservedRegisters(preserved); + sgpr_bounds = program->max_reg_demand.sgpr; vgpr_bounds = program->max_reg_demand.vgpr; num_linear_vgprs = 0; @@ -1076,8 +1080,11 @@ update_renames(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcopy>& p } } -std::optional<PhysReg> -get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info) +/* First value in the pair is the register. The second is the number of preserved registers used. * This pair can be passed as the "best" parameter for another get_reg_simple() call. 
*/ +std::optional<std::pair<PhysReg, unsigned>> +get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info, + std::optional<std::pair<PhysReg, unsigned>> best = {}) { PhysRegInterval bounds = info.bounds; uint32_t size = info.size; @@ -1088,9 +1095,9 @@ get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info) DefInfo new_info = info; new_info.stride = info.stride * 2; if (size % (stride * 2) == 0) { - std::optional<PhysReg> res = get_reg_simple(ctx, reg_file, new_info); - if (res) - return res; + best = get_reg_simple(ctx, reg_file, new_info, best); + if (best && best->second == 0) + return best; } } @@ -1100,27 +1107,41 @@ get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info) assert(bounds.begin() < rr_it); assert(rr_it < bounds.end()); info.bounds = PhysRegInterval::from_until(rr_it.reg, bounds.hi()); - std::optional<PhysReg> res = get_reg_simple(ctx, reg_file, info); - if (res) - return res; + best = get_reg_simple(ctx, reg_file, info, best); + if (best && best->second == 0) + return best; bounds = PhysRegInterval::from_until(bounds.lo(), rr_it.reg); } } - auto is_free = [&](PhysReg reg_index) - { return reg_file[reg_index] == 0 && !ctx.war_hint[reg_index]; }; - for (PhysRegInterval reg_win = {bounds.lo(), size}; reg_win.hi() <= bounds.hi(); reg_win += stride) { - if (std::all_of(reg_win.begin(), reg_win.end(), is_free)) { - if (stride == 1) { - PhysRegIterator new_rr_it{PhysReg{reg_win.lo() + size}}; - if (new_rr_it < bounds.end()) - rr_it = new_rr_it; + bool found = true; + unsigned num_preserved = 0; + for (PhysReg reg : reg_win) { + if (reg_file[reg] != 0 || ctx.war_hint[reg]) { + found = false; + break; } - adjust_max_used_regs(ctx, rc, reg_win.lo()); - return reg_win.lo(); + num_preserved += BITSET_TEST(ctx.preserved, reg); } + if (!found) + continue; + + if (!best || num_preserved < best->second) { + best.emplace(reg_win.lo(), num_preserved); + if (num_preserved == 0) + break; + } + } + if (best) { + if (stride == 1) { + PhysRegIterator new_rr_it{PhysReg{best->first + size}}; + if 
(new_rr_it < bounds.end()) + rr_it = new_rr_it; + } + adjust_max_used_regs(ctx, rc, best->first); + return best; } /* do this late because using the upper bytes of a register can require @@ -1148,7 +1169,8 @@ get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info) PhysReg res{entry.first}; res.reg_b += i; adjust_max_used_regs(ctx, rc, entry.first); - return res; + best.emplace(res, BITSET_TEST(ctx.preserved, entry.first)); + return best; } } } @@ -1218,7 +1240,8 @@ collect_vars_from_bitset(ra_ctx& ctx, RegisterFile& reg_file, const BITSET_DECLA return vars; } -std::optional<PhysReg> +/* The second value in the pair is always zero (it exists for the caller's convenience). */ +std::optional<std::pair<PhysReg, unsigned>> get_reg_for_create_vector_copy(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcopy>& parallelcopies, aco_ptr<Instruction>& instr, const PhysRegInterval def_reg, @@ -1231,7 +1254,7 @@ get_reg_for_create_vector_copy(ra_ctx& ctx, RegisterFile& reg_file, instr->operands[i].isKillBeforeDef()) { assert(!reg_file.test(reg, instr->operands[i].bytes())); if (info.rc.is_subdword() || reg.byte() == 0) - return reg; + return std::make_pair(reg, 0); else return {}; } @@ -1258,11 +1281,13 @@ get_reg_for_create_vector_copy(ra_ctx& ctx, RegisterFile& reg_file, instr->operands[i].regClass() == info.rc) { assignment& op = ctx.assignments[instr->operands[i].tempId()]; /* if everything matches, create parallelcopy for the killed operand */ - if (!intersects(def_reg, PhysRegInterval{op.reg, op.rc.size()}) && op.reg != scc && - reg_file.get_id(op.reg) == instr->operands[i].tempId()) { + PhysRegInterval reg_win{op.reg, op.rc.size()}; + if (!intersects(def_reg, reg_win) && op.reg != scc && + reg_file.get_id(op.reg) == instr->operands[i].tempId() && + !BITSET_TEST_RANGE(ctx.preserved, reg_win.lo(), reg_win.hi().reg() - 1)) { Definition pc_def = Definition(reg, info.rc); parallelcopies.emplace_back(instr->operands[i], pc_def); - return op.reg; + return std::make_pair(op.reg, 0); } } return {}; @@ -1286,7 +1311,7 @@ 
get_regs_for_copies(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcopy - std::optional<PhysReg> res; + std::optional<std::pair<PhysReg, unsigned>> res; if (instr->opcode == aco_opcode::p_create_vector) { res = get_reg_for_create_vector_copy(ctx, reg_file, parallelcopies, instr, def_reg, info, id); @@ -1306,33 +1331,33 @@ get_regs_for_copies(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcopy - if (!res && def_reg.hi() <= bounds.hi()) { + if ((!res || res->second > 0) && def_reg.hi() <= bounds.hi()) { unsigned stride = DIV_ROUND_UP(info.stride, 4); unsigned lo = (def_reg.hi() + stride - 1) & ~(stride - 1); info.bounds = PhysRegInterval::from_until(PhysReg{lo}, bounds.hi()); - res = get_reg_simple(ctx, reg_file, info); + res = get_reg_simple(ctx, reg_file, info, res); } } if (res) { /* mark the area as blocked */ - reg_file.block(*res, var.rc); + reg_file.block(res->first, var.rc); /* create parallelcopy pair (without definition id) */ Temp tmp = Temp(id, var.rc); Operand pc_op = Operand(tmp); pc_op.setFixed(var.reg); - Definition pc_def = Definition(*res, pc_op.regClass()); + Definition pc_def = Definition(res->first, pc_op.regClass()); parallelcopies.emplace_back(pc_op, pc_def); continue; } - std::optional<std::pair<PhysRegInterval, std::tuple<unsigned, int>>> best; + std::optional<std::pair<PhysRegInterval, std::tuple<unsigned, unsigned, int>>> best; /* we use a sliding window to find potential positions */ unsigned stride = DIV_ROUND_UP(info.stride, 4); @@ -1342,11 +1367,14 @@ get_regs_for_copies(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcopy + unsigned num_preserved = 0; + for (PhysReg j : reg_win) + num_preserved += BITSET_TEST(ctx.preserved, j.reg()); - std::tuple<unsigned, int> cost{num_moves, -num_vars}; + std::tuple<unsigned, unsigned, int> cost{num_preserved, num_moves, -num_vars}; if (!best || cost < best->second) best.emplace(reg_win, cost); } @@ -1469,7 +1497,7 @@ get_reg_impl(ra_ctx& ctx, const RegisterFile& reg_file, std::vector<parallelcopy - std::optional<std::pair<PhysRegInterval, std::tuple<unsigned, int, bool>>> best; + std::optional<std::pair<PhysRegInterval, std::tuple<unsigned, unsigned, int, bool>>> best; /* we use a sliding window to check potential positions */ for (PhysRegInterval reg_win = {bounds.lo(), size}; reg_win.hi() <= bounds.hi(); @@ -1486,6 +1514,7 @@ get_reg_impl(ra_ctx& ctx, const RegisterFile& reg_file, std::vector<parallelcopy - std::tuple<unsigned, int, bool> cost{num_moves, -num_vars, !aligned}; + std::tuple<unsigned, unsigned, int, bool> cost{num_preserved, num_moves, -num_vars, !aligned}; if (!best || cost < best->second) best.emplace(reg_win, cost); 
} @@ -1811,13 +1841,13 @@ get_reg_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, aco_ptr<Instruction - std::optional<PhysReg> reg = get_reg_simple(ctx, reg_file, info); + std::optional<std::pair<PhysReg, unsigned>> reg = get_reg_simple(ctx, reg_file, info); if (reg) { - reg->reg_b += our_offset; + reg->first.reg_b += our_offset; /* make sure to only use byte offset if the instruction supports it */ - if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, *reg, operand)) { - ctx.assignments[vec.parts[vec.index].tempId()].set_precolor_affinity(reg.value()); - return reg; + if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, reg->first, operand)) { + ctx.assignments[vec.parts[vec.index].tempId()].set_precolor_affinity(reg->first); + return reg->first; } } } @@ -2000,10 +2030,11 @@ get_reg(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, if (!ctx.policy.skip_optimistic_path && !ctx.policy.use_compact_relocate) { /* try to find space without live-range splits */ - res = get_reg_simple(ctx, reg_file, info); + std::optional<std::pair<PhysReg, unsigned>> simple = get_reg_simple(ctx, reg_file, info); - if (res) - return *res; + /* Prefer moving to make space over using a preserved VGPR. 
*/ + if (simple && (simple->second == 0 || temp.type() == RegType::sgpr)) + return simple->first; } if (!ctx.policy.use_compact_relocate) { @@ -2139,6 +2170,7 @@ get_reg_create_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, /* count variables to be moved and check "avoid" */ bool avoid = false; bool linear_vgpr = false; + unsigned num_preserved = 0; for (PhysReg j : reg_win) { if (reg_file[j] != 0) { if (reg_file[j] == 0xF0000000) { @@ -2153,6 +2185,7 @@ get_reg_create_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, } } avoid |= ctx.war_hint[j]; + num_preserved += BITSET_TEST(ctx.preserved, j.reg()); } /* we cannot split live ranges of linear vgprs */ @@ -2173,7 +2206,7 @@ get_reg_create_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, k += op.bytes(); } bool aligned = rc == RegClass::v4 && reg_win.lo() % 4 == 0; - if (k > num_moves || (!aligned && k == num_moves)) + if (k > num_moves || (!aligned && k == num_moves) || (k > 0 && num_preserved > 0)) continue; best_pos = reg_win.lo(); @@ -2187,9 +2220,10 @@ get_reg_create_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, return get_reg(ctx, reg_file, temp, parallelcopies, instr); } else if (num_moves > bytes) { DefInfo info(ctx, instr, rc, -1); - std::optional<PhysReg> res = get_reg_simple(ctx, reg_file, info); - if (res) - return *res; + /* Prefer moving to make space over using a preserved VGPR. 
*/ + std::optional<std::pair<PhysReg, unsigned>> res = get_reg_simple(ctx, reg_file, info); + if (res && (res->second == 0 || temp.type() == RegType::sgpr)) + return res->first; } /* re-enable killed operands which are in the wrong position */ @@ -3960,9 +3994,12 @@ register_allocation(Program* program, ra_test_policy policy) } else if (i == 0) { RegClass vec_rc = RegClass::get(rc.type(), instr->operands[0].bytes()); DefInfo info(ctx, ctx.pseudo_dummy, vec_rc, -1); - std::optional<PhysReg> res = get_reg_simple(ctx, register_file, info); - if (res && get_reg_specified(ctx, register_file, rc, instr, *res, -1)) - definition->setFixed(*res); + std::optional<std::pair<PhysReg, unsigned>> res = + get_reg_simple(ctx, register_file, info); + /* Prefer using the normal get_reg() path over using a preserved VGPR. */ + if (res && (res->second == 0 || rc.type() == RegType::sgpr) && + get_reg_specified(ctx, register_file, rc, instr, res->first, -1)) + definition->setFixed(res->first); } else if (instr->definitions[i - 1].isFixed()) { reg = instr->definitions[i - 1].physReg(); reg.reg_b += instr->definitions[i - 1].bytes();