aco/ra: prefer clobbered registers in callees

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38679>
This commit was authored by Rhys Perry on 2025-11-21 11:36:13 +00:00 and committed by Marge Bot.
Parent commit: 21b1118a08
Commit: 2b20d568e0

View file

@@ -159,6 +159,7 @@ struct ra_ctx {
std::bitset<512> war_hint;
PhysRegIterator rr_sgpr_it;
PhysRegIterator rr_vgpr_it;
BITSET_DECLARE(preserved, 512) = {};
uint16_t sgpr_bounds;
uint16_t vgpr_bounds;
@ -175,6 +176,9 @@ struct ra_ctx {
phi_dummy.reset(create_instruction(aco_opcode::p_linear_phi, Format::PSEUDO, 0, 0));
limit = get_addr_regs_from_waves(program, program->min_waves);
if (program->is_callee)
program->callee_abi.preservedRegisters(preserved);
sgpr_bounds = program->max_reg_demand.sgpr;
vgpr_bounds = program->max_reg_demand.vgpr;
num_linear_vgprs = 0;
@ -1076,8 +1080,11 @@ update_renames(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcopy>& p
}
}
std::optional<PhysReg>
get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info)
/* First value in the pair is the register. The second is the number of preserved registers used.
* This pair can be passed as the "best" parameter for another get_reg_simple() call. */
std::optional<std::pair<PhysReg, uint32_t>>
get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info,
std::optional<std::pair<PhysReg, uint32_t>> best = {})
{
PhysRegInterval bounds = info.bounds;
uint32_t size = info.size;
@ -1088,9 +1095,9 @@ get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info)
DefInfo new_info = info;
new_info.stride = info.stride * 2;
if (size % (stride * 2) == 0) {
std::optional<PhysReg> res = get_reg_simple(ctx, reg_file, new_info);
if (res)
return res;
best = get_reg_simple(ctx, reg_file, new_info, best);
if (best && best->second == 0)
return best;
}
}
@ -1100,27 +1107,41 @@ get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info)
assert(bounds.begin() < rr_it);
assert(rr_it < bounds.end());
info.bounds = PhysRegInterval::from_until(rr_it.reg, bounds.hi());
std::optional<PhysReg> res = get_reg_simple(ctx, reg_file, info);
if (res)
return res;
best = get_reg_simple(ctx, reg_file, info, best);
if (best && best->second == 0)
return best;
bounds = PhysRegInterval::from_until(bounds.lo(), rr_it.reg);
}
}
auto is_free = [&](PhysReg reg_index)
{ return reg_file[reg_index] == 0 && !ctx.war_hint[reg_index]; };
for (PhysRegInterval reg_win = {bounds.lo(), size}; reg_win.hi() <= bounds.hi();
reg_win += stride) {
if (std::all_of(reg_win.begin(), reg_win.end(), is_free)) {
if (stride == 1) {
PhysRegIterator new_rr_it{PhysReg{reg_win.lo() + size}};
if (new_rr_it < bounds.end())
rr_it = new_rr_it;
bool found = true;
unsigned num_preserved = 0;
for (PhysReg reg : reg_win) {
if (reg_file[reg] != 0 || ctx.war_hint[reg]) {
found = false;
break;
}
adjust_max_used_regs(ctx, rc, reg_win.lo());
return reg_win.lo();
num_preserved += BITSET_TEST(ctx.preserved, reg);
}
if (!found)
continue;
if (!best || num_preserved < best->second) {
best.emplace(reg_win.lo(), num_preserved);
if (num_preserved == 0)
break;
}
}
if (best) {
if (stride == 1) {
PhysRegIterator new_rr_it{PhysReg{best->first + size}};
if (new_rr_it < bounds.end())
rr_it = new_rr_it;
}
adjust_max_used_regs(ctx, rc, best->first);
return best;
}
/* do this late because using the upper bytes of a register can require
@ -1148,7 +1169,8 @@ get_reg_simple(ra_ctx& ctx, const RegisterFile& reg_file, DefInfo info)
PhysReg res{entry.first};
res.reg_b += i;
adjust_max_used_regs(ctx, rc, entry.first);
return res;
best.emplace(res, BITSET_TEST(ctx.preserved, entry.first));
return best;
}
}
}
@ -1218,7 +1240,8 @@ collect_vars_from_bitset(ra_ctx& ctx, RegisterFile& reg_file, const BITSET_DECLA
return vars;
}
std::optional<PhysReg>
/* The second value in the pair is always zero (it exists for the caller's convenience). */
std::optional<std::pair<PhysReg, uint32_t>>
get_reg_for_create_vector_copy(ra_ctx& ctx, RegisterFile& reg_file,
std::vector<parallelcopy>& parallelcopies,
aco_ptr<Instruction>& instr, const PhysRegInterval def_reg,
@ -1231,7 +1254,7 @@ get_reg_for_create_vector_copy(ra_ctx& ctx, RegisterFile& reg_file,
instr->operands[i].isKillBeforeDef()) {
assert(!reg_file.test(reg, instr->operands[i].bytes()));
if (info.rc.is_subdword() || reg.byte() == 0)
return reg;
return std::make_pair(reg, 0);
else
return {};
}
@ -1258,11 +1281,13 @@ get_reg_for_create_vector_copy(ra_ctx& ctx, RegisterFile& reg_file,
instr->operands[i].regClass() == info.rc) {
assignment& op = ctx.assignments[instr->operands[i].tempId()];
/* if everything matches, create parallelcopy for the killed operand */
if (!intersects(def_reg, PhysRegInterval{op.reg, op.rc.size()}) && op.reg != scc &&
reg_file.get_id(op.reg) == instr->operands[i].tempId()) {
PhysRegInterval reg_win{op.reg, op.rc.size()};
if (!intersects(def_reg, reg_win) && op.reg != scc &&
reg_file.get_id(op.reg) == instr->operands[i].tempId() &&
!BITSET_TEST_RANGE(ctx.preserved, reg_win.lo(), reg_win.hi().reg() - 1)) {
Definition pc_def = Definition(reg, info.rc);
parallelcopies.emplace_back(instr->operands[i], pc_def);
return op.reg;
return std::make_pair(op.reg, 0);
}
}
return {};
@ -1286,7 +1311,7 @@ get_regs_for_copies(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcop
/* check if this is a dead operand, then we can re-use the space from the definition
* also use the correct stride for sub-dword operands */
bool is_dead_operand = false;
std::optional<PhysReg> res;
std::optional<std::pair<PhysReg, uint32_t>> res;
if (instr->opcode == aco_opcode::p_create_vector) {
res =
get_reg_for_create_vector_copy(ctx, reg_file, parallelcopies, instr, def_reg, info, id);
@ -1306,33 +1331,33 @@ get_regs_for_copies(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcop
if (!res && !def_reg.size) {
/* If this is before definitions are handled, def_reg may be an empty interval. */
info.bounds = bounds;
res = get_reg_simple(ctx, reg_file, info);
res = get_reg_simple(ctx, reg_file, info, res);
} else if (!res) {
/* Try to find space within the bounds but outside of the definition */
info.bounds = PhysRegInterval::from_until(bounds.lo(), MIN2(def_reg.lo(), bounds.hi()));
res = get_reg_simple(ctx, reg_file, info);
if (!res && def_reg.hi() <= bounds.hi()) {
if ((!res || res->second > 0) && def_reg.hi() <= bounds.hi()) {
unsigned stride = DIV_ROUND_UP(info.stride, 4);
unsigned lo = (def_reg.hi() + stride - 1) & ~(stride - 1);
info.bounds = PhysRegInterval::from_until(PhysReg{lo}, bounds.hi());
res = get_reg_simple(ctx, reg_file, info);
res = get_reg_simple(ctx, reg_file, info, res);
}
}
if (res) {
/* mark the area as blocked */
reg_file.block(*res, var.rc);
reg_file.block(res->first, var.rc);
/* create parallelcopy pair (without definition id) */
Temp tmp = Temp(id, var.rc);
Operand pc_op = Operand(tmp);
pc_op.setFixed(var.reg);
Definition pc_def = Definition(*res, pc_op.regClass());
Definition pc_def = Definition(res->first, pc_op.regClass());
parallelcopies.emplace_back(pc_op, pc_def);
continue;
}
std::optional<std::pair<PhysRegInterval, std::tuple<unsigned, int>>> best;
std::optional<std::pair<PhysRegInterval, std::tuple<unsigned, unsigned, int>>> best;
/* we use a sliding window to find potential positions */
unsigned stride = DIV_ROUND_UP(info.stride, 4);
@ -1342,11 +1367,14 @@ get_regs_for_copies(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcop
continue;
/* second, check that no element is larger than the currently processed one */
unsigned num_preserved = 0;
unsigned num_moves = 0;
int num_vars = 0;
unsigned last_var = 0;
bool found = true;
for (PhysReg j : reg_win) {
num_preserved += BITSET_TEST(ctx.preserved, j.reg());
if (reg_file[j] == 0 || reg_file[j] == last_var)
continue;
@ -1384,7 +1412,7 @@ get_regs_for_copies(ra_ctx& ctx, RegisterFile& reg_file, std::vector<parallelcop
if (!found)
continue;
std::tuple<unsigned, int> cost{num_moves, -num_vars};
std::tuple<unsigned, unsigned, int> cost{num_preserved, num_moves, -num_vars};
if (!best || cost < best->second)
best.emplace(reg_win, cost);
}
@ -1469,7 +1497,7 @@ get_reg_impl(ra_ctx& ctx, const RegisterFile& reg_file, std::vector<parallelcopy
op_moves = size - (regs_free - killed_ops);
/* find the best position to place the definition */
std::optional<std::pair<PhysRegInterval, std::tuple<unsigned, int, bool>>> best;
std::optional<std::pair<PhysRegInterval, std::tuple<unsigned, unsigned, int, bool>>> best;
/* we use a sliding window to check potential positions */
for (PhysRegInterval reg_win = {bounds.lo(), size}; reg_win.hi() <= bounds.hi();
@ -1486,6 +1514,7 @@ get_reg_impl(ra_ctx& ctx, const RegisterFile& reg_file, std::vector<parallelcopy
/* second, check that we have at most k=num_moves elements in the window
* and no element is larger than the currently processed one */
unsigned num_preserved = 0;
unsigned num_moves = op_moves;
int num_vars = 0;
unsigned remaining_op_moves = op_moves;
@ -1493,6 +1522,8 @@ get_reg_impl(ra_ctx& ctx, const RegisterFile& reg_file, std::vector<parallelcopy
bool found = true;
bool aligned = rc == RegClass::v4 && reg_win.lo() % 4 == 0;
for (const PhysReg j : reg_win) {
num_preserved += BITSET_TEST(ctx.preserved, j.reg());
/* dead operands effectively reduce the number of estimated moves */
if (is_killed_operand[j & 0xFF]) {
if (remaining_op_moves) {
@ -1530,11 +1561,10 @@ get_reg_impl(ra_ctx& ctx, const RegisterFile& reg_file, std::vector<parallelcopy
num_vars++;
last_var = reg_file[j];
}
if (!found)
continue;
std::tuple<unsigned, int, bool> cost{num_moves, -num_vars, !aligned};
std::tuple<unsigned, unsigned, int, bool> cost{num_preserved, num_moves, -num_vars, !aligned};
if (!best || cost < best->second)
best.emplace(reg_win, cost);
}
@ -1811,13 +1841,13 @@ get_reg_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp, aco_ptr<Ins
*/
RegClass vec_rc = RegClass::get(temp.type(), their_offset);
DefInfo info(ctx, ctx.pseudo_dummy, vec_rc, -1);
std::optional<PhysReg> reg = get_reg_simple(ctx, reg_file, info);
std::optional<std::pair<PhysReg, uint32_t>> reg = get_reg_simple(ctx, reg_file, info);
if (reg) {
reg->reg_b += our_offset;
reg->first.reg_b += our_offset;
/* make sure to only use byte offset if the instruction supports it */
if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, *reg, operand)) {
ctx.assignments[vec.parts[vec.index].tempId()].set_precolor_affinity(reg.value());
return reg;
if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, reg->first, operand)) {
ctx.assignments[vec.parts[vec.index].tempId()].set_precolor_affinity(reg->first);
return reg->first;
}
}
}
@ -2000,10 +2030,11 @@ get_reg(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
if (!ctx.policy.skip_optimistic_path && !ctx.policy.use_compact_relocate) {
/* try to find space without live-range splits */
res = get_reg_simple(ctx, reg_file, info);
std::optional<std::pair<PhysReg, uint32_t>> simple = get_reg_simple(ctx, reg_file, info);
if (res)
return *res;
/* Prefer moving to make space over using a preserved VGPR. */
if (simple && (simple->second == 0 || temp.type() == RegType::sgpr))
return simple->first;
}
if (!ctx.policy.use_compact_relocate) {
@ -2139,6 +2170,7 @@ get_reg_create_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
/* count variables to be moved and check "avoid" */
bool avoid = false;
bool linear_vgpr = false;
unsigned num_preserved = 0;
for (PhysReg j : reg_win) {
if (reg_file[j] != 0) {
if (reg_file[j] == 0xF0000000) {
@ -2153,6 +2185,7 @@ get_reg_create_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
}
}
avoid |= ctx.war_hint[j];
num_preserved += BITSET_TEST(ctx.preserved, j.reg());
}
/* we cannot split live ranges of linear vgprs */
@ -2173,7 +2206,7 @@ get_reg_create_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
k += op.bytes();
}
bool aligned = rc == RegClass::v4 && reg_win.lo() % 4 == 0;
if (k > num_moves || (!aligned && k == num_moves))
if (k > num_moves || (!aligned && k == num_moves) || (k > 0 && num_preserved > 0))
continue;
best_pos = reg_win.lo();
@ -2187,9 +2220,10 @@ get_reg_create_vector(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
return get_reg(ctx, reg_file, temp, parallelcopies, instr);
} else if (num_moves > bytes) {
DefInfo info(ctx, instr, rc, -1);
std::optional<PhysReg> res = get_reg_simple(ctx, reg_file, info);
if (res)
return *res;
/* Prefer moving to make space over using a preserved VGPR. */
std::optional<std::pair<PhysReg, uint32_t>> res = get_reg_simple(ctx, reg_file, info);
if (res && (res->second == 0 || temp.type() == RegType::sgpr))
return res->first;
}
/* re-enable killed operands which are in the wrong position */
@ -3960,9 +3994,12 @@ register_allocation(Program* program, ra_test_policy policy)
} else if (i == 0) {
RegClass vec_rc = RegClass::get(rc.type(), instr->operands[0].bytes());
DefInfo info(ctx, ctx.pseudo_dummy, vec_rc, -1);
std::optional<PhysReg> res = get_reg_simple(ctx, register_file, info);
if (res && get_reg_specified(ctx, register_file, rc, instr, *res, -1))
definition->setFixed(*res);
std::optional<std::pair<PhysReg, uint32_t>> res =
get_reg_simple(ctx, register_file, info);
/* Prefer using the normal get_reg() path over using a preserved VGPR. */
if (res && (res->second == 0 || rc.type() == RegType::sgpr) &&
get_reg_specified(ctx, register_file, rc, instr, res->first, -1))
definition->setFixed(res->first);
} else if (instr->definitions[i - 1].isFixed()) {
reg = instr->definitions[i - 1].physReg();
reg.reg_b += instr->definitions[i - 1].bytes();