Merge branch 'aco_misc_mem_ra_opts' into 'main'

aco: improve register allocation and phis with memory loads

See merge request mesa/mesa!38262
Rhys Perry 2025-12-20 00:06:24 +00:00
commit 8f67192c4b
8 changed files with 380 additions and 31 deletions


@@ -126,6 +126,7 @@ amd_common_files = files(
'nir/ac_nir_meta_cs_blit.c',
'nir/ac_nir_meta_cs_clear_copy_buffer.c',
'nir/ac_nir_meta_ps_resolve.c',
'nir/ac_nir_opt_flip_if_for_mem_loads.c',
'nir/ac_nir_opt_pack_half.c',
'nir/ac_nir_opt_shared_append.c',
'nir/ac_nir_prerast_utils.c',


@@ -412,6 +412,9 @@ ac_nir_varying_expression_max_cost(nir_shader *producer, nir_shader *consumer);
bool
ac_nir_opt_shared_append(nir_shader *shader);
bool
ac_nir_opt_flip_if_for_mem_loads(nir_shader *shader);
bool
ac_nir_flag_smem_for_loads(nir_shader *shader, enum amd_gfx_level gfx_level, bool use_llvm);


@@ -0,0 +1,197 @@
/*
* Copyright 2025 Valve Corporation
*
* SPDX-License-Identifier: MIT
*/
/*
* This pass flips divergent branches if the then-side contains a memory load,
* and the else-side does not. This is useful because VMEM/LDS->VALU WaW on
* GFX11+ requires a waitcnt, even if the two writes have no lanes in common.
* By flipping the branch, it becomes a VALU->VMEM/LDS WaW, which requires no
* waitcnt.
*
* A typical case is a VMEM load and a constant:
* if (divergent_condition) {
* a = tex()
* } else {
* a = 0.0;
* }
* which becomes:
* if (!divergent_condition) {
* a = 0.0;
* } else {
* a = tex()
* }
*
* Note that it's best to run this before nir_opt_algebraic, to optimize out
* the inot, and after nir_opt_if, because opt_if_simplification can undo this
* optimization.
*/
#include "ac_nir.h"
#include "nir_builder.h"
enum {
is_vmem_lds = 1 << 0,
is_other = 1 << 1,
};
static unsigned
is_vmem_or_lds_load(nir_def *def, unsigned depth, unsigned begin, unsigned end)
{
if (nir_def_instr(def)->block->index < begin ||
nir_def_instr(def)->block->index > end ||
depth > 4)
return 0;
switch (nir_def_instr(def)->type) {
case nir_instr_type_alu: {
nir_alu_instr *alu = nir_def_as_alu(def);
/* ACO has an optimization to combine u2u32 into a load instruction, so treat it like a mov. */
if (!nir_op_is_vec_or_mov(alu->op) &&
!(alu->op == nir_op_u2u32 && alu->src[0].src.ssa->bit_size < 32))
return is_other;
unsigned res = 0;
for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
res |= is_vmem_or_lds_load(alu->src[i].src.ssa, depth + 1, begin, end);
return res;
}
case nir_instr_type_phi: {
unsigned res = 0;
nir_foreach_phi_src (src, nir_def_as_phi(def))
res |= is_vmem_or_lds_load(src->src.ssa, depth + 1, begin, end);
return res;
}
case nir_instr_type_tex:
return is_vmem_lds;
case nir_instr_type_intrinsic: {
nir_intrinsic_instr *intrin = nir_def_as_intrinsic(def);
if (nir_intrinsic_has_access(intrin) && (nir_intrinsic_access(intrin) & ACCESS_SMEM_AMD))
return is_other;
switch (intrin->intrinsic) {
case nir_intrinsic_load_ubo:
case nir_intrinsic_load_ssbo:
case nir_intrinsic_load_global:
case nir_intrinsic_load_global_constant:
case nir_intrinsic_load_global_amd:
case nir_intrinsic_load_scratch:
case nir_intrinsic_load_shared:
case nir_intrinsic_load_constant:
case nir_intrinsic_bindless_image_load:
case nir_intrinsic_bindless_image_sparse_load:
case nir_intrinsic_bindless_image_fragment_mask_load_amd:
case nir_intrinsic_load_buffer_amd:
case nir_intrinsic_load_typed_buffer_amd:
case nir_intrinsic_ssbo_atomic:
case nir_intrinsic_ssbo_atomic_swap:
case nir_intrinsic_global_atomic:
case nir_intrinsic_global_atomic_swap:
case nir_intrinsic_global_atomic_amd:
case nir_intrinsic_global_atomic_swap_amd:
case nir_intrinsic_shared_atomic:
case nir_intrinsic_shared_atomic_swap:
case nir_intrinsic_bindless_image_atomic:
case nir_intrinsic_bindless_image_atomic_swap:
return is_vmem_lds;
default:
return is_other;
}
}
case nir_instr_type_undef:
return 0;
default:
return is_other;
}
}
static bool
opt_flip_if_for_mem_loads_impl(nir_function_impl *impl)
{
nir_metadata_require(impl, nir_metadata_block_index | nir_metadata_divergence);
nir_builder b = nir_builder_create(impl);
bool progress = false;
nir_foreach_block(block, impl) {
nir_if *nif = nir_block_get_following_if(block);
if (!nif || !nir_src_is_divergent(&nif->condition))
continue;
nir_block *merge = nir_cf_node_cf_tree_next(&nif->cf_node);
nir_block *then_block = nir_if_last_then_block(nif);
nir_block *else_block = nir_if_last_else_block(nif);
if (nir_block_ends_in_jump(then_block) || nir_block_ends_in_jump(else_block))
continue;
uint32_t then_first = nir_if_first_then_block(nif)->index;
uint32_t then_last = nir_if_last_then_block(nif)->index;
uint32_t else_first = nir_if_first_else_block(nif)->index;
uint32_t else_last = nir_if_last_else_block(nif)->index;
bool then_loads = false;
bool else_loads = false;
nir_foreach_phi(phi, merge) {
nir_phi_src *s_then = nir_phi_get_src_from_block(phi, then_block);
nir_phi_src *s_else = nir_phi_get_src_from_block(phi, else_block);
unsigned then_src = is_vmem_or_lds_load(s_then->src.ssa, 0, then_first, then_last);
unsigned else_src = is_vmem_or_lds_load(s_else->src.ssa, 0, else_first, else_last);
then_loads |=
(then_src & is_vmem_lds) &&
((else_src & is_other) ||
(!list_is_singular(&s_else->src.ssa->uses) && !nir_src_is_undef(s_else->src)) ||
nir_src_is_const(s_else->src));
else_loads |=
(else_src & is_vmem_lds) &&
((then_src & is_other) ||
(!list_is_singular(&s_then->src.ssa->uses) && !nir_src_is_undef(s_then->src)) ||
nir_src_is_const(s_then->src));
}
if (!then_loads || else_loads)
continue;
/* invert the condition */
nir_scalar cond = nir_get_scalar(nif->condition.ssa, 0);
nir_def *inv_cond = NULL;
b.cursor = nir_before_src(&nif->condition);
if (nir_scalar_is_intrinsic(cond) && nir_scalar_intrinsic_op(cond) == nir_intrinsic_inverse_ballot) {
nir_intrinsic_instr *intrin = nir_def_as_intrinsic(cond.def);
nir_scalar src = nir_scalar_resolved(intrin->src[0].ssa, 0);
if (nir_scalar_is_const(src))
inv_cond = nir_inverse_ballot_imm(&b, ~nir_scalar_as_uint(src), src.def->bit_size);
}
nir_src_rewrite(&nif->condition, inv_cond ? inv_cond : nir_inot(&b, nif->condition.ssa));
/* rewrite phi predecessors */
nir_foreach_phi(phi, merge) {
nir_foreach_phi_src(src, phi)
src->pred = src->pred == then_block ? else_block : then_block;
}
/* swap the cf_lists */
nir_cf_list then_list, else_list;
nir_cf_extract(&then_list, nir_before_cf_list(&nif->then_list),
nir_after_cf_list(&nif->then_list));
nir_cf_extract(&else_list, nir_before_cf_list(&nif->else_list),
nir_after_cf_list(&nif->else_list));
nir_cf_reinsert(&then_list, nir_before_cf_list(&nif->else_list));
nir_cf_reinsert(&else_list, nir_before_cf_list(&nif->then_list));
progress = true;
}
return nir_progress(progress, impl, 0);
}
bool
ac_nir_opt_flip_if_for_mem_loads(nir_shader *shader)
{
bool progress = false;
nir_foreach_function_impl(impl, shader)
progress |= opt_flip_if_for_mem_loads_impl(impl);
return progress;
}
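The header comment of the new pass fixes its ordering: after nir_opt_if (so opt_if_simplification cannot undo the flip) and before nir_opt_algebraic (so the inserted inot gets folded). Below is a minimal, illustrative sketch of that ordering in a driver's late NIR optimization code, modeled on the RADV and RadeonSI hunks further down; the helper name, the surrounding variables and the nir_opt_if options are assumptions, not part of this change.

/* Illustrative sketch only: run the flip between nir_opt_if and
 * nir_opt_algebraic, and only on GFX11+ where the VMEM/LDS->VALU
 * WaW waitcnt exists. */
#include "nir.h"
#include "ac_nir.h"

static void
run_late_if_opts(nir_shader *nir, enum amd_gfx_level gfx_level)
{
   bool progress = false;
   NIR_PASS(progress, nir, nir_opt_if, nir_opt_if_optimize_phi_true_false);
   if (gfx_level >= GFX11)
      NIR_PASS(progress, nir, ac_nir_opt_flip_if_for_mem_loads);
   NIR_PASS(progress, nir, nir_opt_algebraic);
}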


@@ -58,6 +58,7 @@ struct assignment {
bool assigned : 1;
bool precolor_affinity : 1;
bool renamed : 1;
uint8_t weight : 2;
};
uint8_t _ = 0;
};
@@ -78,6 +79,7 @@ struct assignment {
precolor_affinity = true;
reg = affinity_reg;
}
void update_weight(assignment& other) { weight = MAX2(weight, other.weight); }
};
/* Iterator type for making PhysRegInterval compatible with range-based for */
@@ -1172,14 +1174,9 @@ find_vars(ra_ctx& ctx, const RegisterFile& reg_file, const PhysRegInterval reg_i
return vars;
}
/* collect variables from a register area and clear reg_file
* variables are sorted in decreasing size and
* increasing assigned register
*/
std::vector<unsigned>
collect_vars(ra_ctx& ctx, RegisterFile& reg_file, const PhysRegInterval reg_interval)
void
collect_vars(ra_ctx& ctx, RegisterFile& reg_file, std::vector<unsigned>& ids)
{
std::vector<unsigned> ids = find_vars(ctx, reg_file, reg_interval);
std::sort(ids.begin(), ids.end(),
[&](unsigned a, unsigned b)
{
@@ -1193,6 +1190,17 @@ collect_vars(ra_ctx& ctx, RegisterFile& reg_file, const PhysRegInterval reg_inte
assignment& var = ctx.assignments[id];
reg_file.clear(var.reg, var.rc);
}
}
/* collect variables from a register area and clear reg_file
* variables are sorted in decreasing size and
* increasing assigned register
*/
std::vector<unsigned>
collect_vars(ra_ctx& ctx, RegisterFile& reg_file, const PhysRegInterval reg_interval)
{
std::vector<unsigned> ids = find_vars(ctx, reg_file, reg_interval);
collect_vars(ctx, reg_file, ids);
return ids;
}
@@ -1938,6 +1946,45 @@ should_compact_linear_vgprs(ra_ctx& ctx, const RegisterFile& reg_file)
return max_vgpr_usage > get_reg_bounds(ctx, RegType::vgpr, false).size;
}
std::optional<PhysReg>
get_reg_affinity(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
std::vector<parallelcopy>& parallelcopies, aco_ptr<Instruction>& instr,
int operand_index, assignment& affinity)
{
/* check if the target register is blocked */
if (operand_index == -1 && reg_file.test(affinity.reg, temp.bytes())) {
const PhysRegInterval def_regs{PhysReg(affinity.reg.reg()), temp.size()};
std::vector<unsigned> vars = find_vars(ctx, reg_file, def_regs);
/* Bail if the cost of moving the blocking var is likely more expensive
* than assigning a different register.
*/
if (std::any_of(vars.begin(), vars.end(), [&](unsigned id) -> bool
{ return ctx.assignments[id].weight >= ctx.assignments[temp.id()].weight; }))
return {};
RegisterFile tmp_file(reg_file);
collect_vars(ctx, tmp_file, vars);
/* re-enable the killed operands, so that we don't move the blocking vars there */
if (!is_phi(instr))
tmp_file.fill_killed_operands(instr.get());
/* create parallelcopy to move blocking vars */
std::vector<parallelcopy> pc;
if (get_reg_specified(ctx, tmp_file, temp.regClass(), instr, affinity.reg, operand_index) &&
get_regs_for_copies(ctx, tmp_file, pc, vars, instr, def_regs)) {
parallelcopies.insert(parallelcopies.end(), pc.begin(), pc.end());
return affinity.reg;
}
} else if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, affinity.reg,
operand_index)) {
return affinity.reg;
}
return {};
}
PhysReg
get_reg(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
std::vector<parallelcopy>& parallelcopies, aco_ptr<Instruction>& instr,
@@ -1964,11 +2011,15 @@ get_reg(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
}
}
std::optional<PhysReg> res;
if (ctx.assignments[temp.id()].affinity) {
assignment& affinity = ctx.assignments[ctx.assignments[temp.id()].affinity];
if (affinity.assigned) {
if (get_reg_specified(ctx, reg_file, temp.regClass(), instr, affinity.reg, operand_index))
return affinity.reg;
res =
get_reg_affinity(ctx, reg_file, temp, parallelcopies, instr, operand_index, affinity);
if (res)
return *res;
}
}
if (ctx.assignments[temp.id()].precolor_affinity) {
@@ -1977,8 +2028,6 @@ get_reg(ra_ctx& ctx, const RegisterFile& reg_file, Temp temp,
return ctx.assignments[temp.id()].reg;
}
std::optional<PhysReg> res;
if (ctx.vectors.find(temp.id()) != ctx.vectors.end()) {
res = get_reg_vector(ctx, reg_file, temp, instr, operand_index);
if (res)
@@ -2644,25 +2693,48 @@ get_regs_for_phis(ra_ctx& ctx, Block& block, RegisterFile& register_file,
if (definition.isFixed())
continue;
/* use affinity if available */
/* Preferring the operands that are more expensive to copy doesn't do much for logical phis on
* GFX11+ because it creates a waitcnt anyway. */
bool avoid_heavy_copies =
ctx.program->gfx_level < GFX11 || phi->opcode == aco_opcode::p_linear_phi;
std::optional<assignment> affinity;
if (ctx.assignments[definition.tempId()].affinity &&
ctx.assignments[ctx.assignments[definition.tempId()].affinity].assigned) {
assignment& affinity = ctx.assignments[ctx.assignments[definition.tempId()].affinity];
assert(affinity.rc == definition.regClass());
if (get_reg_specified(ctx, register_file, definition.regClass(), phi, affinity.reg, -1)) {
definition.setFixed(affinity.reg);
affinity.emplace(ctx.assignments[ctx.assignments[definition.tempId()].affinity]);
}
small_vec<std::pair<unsigned, unsigned>, 4> operands;
/* by going backwards, we aim to avoid copies in else-blocks */
for (int i = phi->operands.size() - 1; i >= 0; i--) {
const Operand& op = phi->operands[i];
if (!op.isTemp() || !op.isFixed())
continue;
operands.emplace_back(ctx.assignments[op.tempId()].weight, i);
/* Don't use the affinity if it might end up creating a waitcnt. */
if (avoid_heavy_copies && affinity && op.physReg() != affinity->reg &&
ctx.assignments[op.tempId()].weight > 0)
affinity.reset();
}
/* use affinity if available */
if (affinity) {
assert(affinity->rc == definition.regClass());
if (get_reg_specified(ctx, register_file, definition.regClass(), phi, affinity->reg, -1)) {
definition.setFixed(affinity->reg);
register_file.fill(definition);
ctx.assignments[definition.tempId()].set(definition);
continue;
}
}
/* by going backwards, we aim to avoid copies in else-blocks */
for (int i = phi->operands.size() - 1; i >= 0; i--) {
const Operand& op = phi->operands[i];
if (!op.isTemp() || !op.isFixed())
continue;
/* If avoid_heavy_copies=false, then this is already sorted how we want it to be. */
if (avoid_heavy_copies)
std::sort(operands.begin(), operands.end(), std::greater());
for (auto pair : operands) {
const Operand& op = phi->operands[pair.second];
PhysReg reg = op.physReg();
if (get_reg_specified(ctx, register_file, definition.regClass(), phi, reg, -1)) {
definition.setFixed(reg);
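As a standalone illustration of the operand ordering above (not ACO code): candidates are collected back to front so that, at equal weight, the later (typically else-side) operand's register is tried first and any copy lands in the then-block, while a heavier weight, i.e. a value loaded from memory, is always tried before cheaper ones. The phi shape and weights below are made up.

#include <stdio.h>
#include <stdlib.h>

/* (weight, operand index) pairs, mirroring the small_vec of pairs above. */
struct cand { unsigned weight, index; };

/* Descending by weight, then by index, like std::sort with std::greater()
 * on pair<weight, index>. */
static int
cmp_cand_desc(const void *pa, const void *pb)
{
   const struct cand *a = pa, *b = pb;
   if (a->weight != b->weight)
      return a->weight < b->weight ? 1 : -1;
   return a->index < b->index ? 1 : -1;
}

int main(void)
{
   /* Hypothetical phi: operand 0 (then side) is a texture fetch result
    * (weight 3), operand 1 (else side) is a constant move (weight 0). */
   struct cand cands[] = { {0, 1}, {3, 0} }; /* collected back to front */
   qsort(cands, 2, sizeof(cands[0]), cmp_cand_desc);
   /* The phi definition tries operand 0's register first, so the
    * parallelcopy ends up copying the cheap constant instead. */
   printf("first candidate: operand %u (weight %u)\n",
          cands[0].index, cands[0].weight);
   return 0;
}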
@@ -2689,6 +2761,14 @@ get_regs_for_phis(ra_ctx& ctx, Block& block, RegisterFile& register_file,
ctx.assignments[definition.tempId()].set(definition);
}
for (aco_ptr<Instruction>& phi : instructions) {
for (Operand op : phi->operands) {
if (!op.isTemp() || !op.isFixed() || op.physReg() != phi->definitions[0].physReg())
continue;
ctx.assignments[phi->definitions[0].tempId()].update_weight(ctx.assignments[op.tempId()]);
}
}
/* Provide a scratch register in case we need to preserve SCC */
if (has_linear_phis || block.kind & block_kind_loop_header) {
PhysReg scratch_reg = scc;
@@ -3267,6 +3347,51 @@ get_affinities(ra_ctx& ctx)
ctx.vectors[vec[0].id()] = it->second;
}
}
/* If split definitions have affinities with other temporaries, try to allocate those temporaries
* as a vector. */
for (std::pair<uint32_t, Instruction *> pair : ctx.split_vectors) {
Instruction *split = pair.second;
vector_info info;
info.num_parts = split->definitions.size();
unsigned num_temps = 0;
for (unsigned i = 0; i < split->definitions.size(); i++) {
Definition def = split->definitions[i];
uint32_t id = ctx.assignments[def.tempId()].affinity;
if (!id || def.regClass().type() != split->operands[0].regClass().type())
continue;
if (!info.parts) {
info.parts =
(Operand*)ctx.memory.allocate(sizeof(Operand) * info.num_parts, alignof(Operand));
for (unsigned j = 0; j < split->definitions.size(); j++)
info.parts[j] = Operand(split->definitions[j].regClass());
}
info.parts[i] = Operand(Temp(id, ctx.program->temp_rc[id]));
num_temps++;
}
if (!num_temps)
continue;
for (unsigned i = 0; i < split->definitions.size(); i++) {
uint32_t id = info.parts[i].tempId();
if (!id)
continue;
/* If the new vector affinity only includes one temporary, only overwrite the old one if
* the new one is stronger. */
auto vec_it = ctx.vectors.find(id);
if (num_temps == 1 && vec_it != ctx.vectors.end() &&
(!vec_it->second.is_weak || info.is_weak))
continue;
info.index = i;
ctx.vectors[id] = info;
}
}
}
void
@@ -3830,6 +3955,13 @@ register_allocation(Program* program, ra_test_policy policy)
optimize_encoding(ctx, register_file, instr);
if ((instr->isVMEM() || instr->isFlatLike()) && !instr->definitions.empty())
ctx.assignments[instr->definitions[0].tempId()].weight = 3;
if (instr->isSMEM() && !instr->definitions.empty())
ctx.assignments[instr->definitions[0].tempId()].weight = 2;
if (instr->isDS() && !instr->definitions.empty())
ctx.assignments[instr->definitions[0].tempId()].weight = 1;
auto tied_defs = get_tied_defs(instr.get());
handle_operands_tied_to_definitions(ctx, parallelcopy, instr, register_file, tied_defs);
@@ -3932,6 +4064,8 @@ register_allocation(Program* program, ra_test_policy policy)
reg.reg_b += instr->definitions[j].bytes();
if (get_reg_specified(ctx, register_file, rc, instr, reg, -1)) {
definition->setFixed(reg);
ctx.assignments[definition->tempId()].update_weight(
ctx.assignments[instr->operands[0].tempId()]);
} else if (i == 0) {
RegClass vec_rc = RegClass::get(rc.type(), instr->operands[0].bytes());
DefInfo info(ctx, ctx.pseudo_dummy, vec_rc, -1);
@@ -3953,13 +4087,24 @@ register_allocation(Program* program, ra_test_policy policy)
} else if (instr->opcode == aco_opcode::p_extract_vector) {
PhysReg reg = instr->operands[0].physReg();
reg.reg_b += definition->bytes() * instr->operands[1].constantValue();
if (get_reg_specified(ctx, register_file, definition->regClass(), instr, reg, -1))
if (get_reg_specified(ctx, register_file, definition->regClass(), instr, reg, -1)) {
definition->setFixed(reg);
ctx.assignments[definition->tempId()].update_weight(
ctx.assignments[instr->operands[0].tempId()]);
}
} else if (instr->opcode == aco_opcode::p_create_vector) {
PhysReg reg = get_reg_create_vector(ctx, register_file, definition->getTemp(),
parallelcopy, instr);
update_renames(ctx, register_file, parallelcopy, instr);
definition->setFixed(reg);
unsigned offset = 0;
for (const Operand& op : instr->operands) {
if (op.isTemp() && op.physReg() == reg.advance(offset))
ctx.assignments[definition->tempId()].update_weight(
ctx.assignments[op.tempId()]);
offset += op.bytes();
}
} else if (instr_info.classes[(int)instr->opcode] == instr_class::wmma &&
instr->operands[2].isTemp() && instr->operands[2].isKill() &&
instr->operands[2].regClass() == definition->regClass()) {
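For reference, a standalone sketch of the weight heuristic this file now uses (illustrative stand-in types, not ACO's): a definition's weight encodes how costly a VALU write-after-write with it would be on GFX11+ (VMEM/FLAT 3, SMEM 2, DS 1, everything else 0), the weight is propagated through phis, splits, extracts and create_vector, and get_reg_affinity only displaces a blocking variable whose weight is strictly lower than that of the value being placed.

#include <stdbool.h>
#include <stdio.h>

/* Stand-ins for ACO's instr->isVMEM()/isFlatLike()/isSMEM()/isDS() checks. */
enum def_kind { DEF_OTHER, DEF_DS, DEF_SMEM, DEF_VMEM };

/* Mirrors the weight assignment in register_allocation() above
 * (stored in the new 2-bit field of struct assignment). */
static unsigned
def_weight(enum def_kind kind)
{
   switch (kind) {
   case DEF_VMEM: return 3;
   case DEF_SMEM: return 2;
   case DEF_DS:   return 1;
   default:       return 0;
   }
}

/* Mirrors the bail-out in get_reg_affinity(): only move a blocking variable
 * out of the affinity register if it is strictly cheaper to copy than the
 * value we want to place there. */
static bool
may_displace(unsigned blocker_weight, unsigned candidate_weight)
{
   return blocker_weight < candidate_weight;
}

int main(void)
{
   unsigned tex = def_weight(DEF_VMEM);   /* 3 */
   unsigned zero = def_weight(DEF_OTHER); /* 0 */
   printf("displace a constant for a VMEM result: %d\n", may_displace(zero, tex)); /* 1 */
   printf("displace a VMEM result for another:    %d\n", may_displace(tex, tex));  /* 0 */
   return 0;
}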


@@ -573,8 +573,12 @@ remove_entry(SchedILPContext& ctx, const Instruction* const instr, const uint32_
if (ctx.regs[reg].has_direct_dependency && ctx.regs[reg].direct_dependency == idx) {
ctx.regs[reg].has_direct_dependency = false;
if (!ctx.is_vopd) {
/* Do MAX2() so that the latencies from both predecessors of a merge block are considered. */
if (BITSET_TEST(ctx.reg_has_latency, reg))
ctx.regs[reg].latency = MAX2(ctx.regs[reg].latency, latency);
else
ctx.regs[reg].latency = latency;
BITSET_SET(ctx.reg_has_latency, reg);
ctx.regs[reg].latency = latency;
}
}
}
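The MAX2() above matters when both predecessors of a merge block wrote the same register: the remembered latency should be the worse of the two, not whichever predecessor happened to be processed last. A tiny sketch with made-up latencies:

#include <stdio.h>

#define MAX2(a, b) ((a) > (b) ? (a) : (b))

int main(void)
{
   /* Hypothetical: an LDS load on the then side and a VALU move on the
    * else side both write the same register before the merge block. */
   const unsigned pred_latency[2] = {40, 4};
   unsigned latency = 0;
   int has_latency = 0;

   for (int i = 0; i < 2; i++) {
      latency = has_latency ? MAX2(latency, pred_latency[i]) : pred_latency[i];
      has_latency = 1;
   }
   printf("latency kept at the merge block: %u\n", latency); /* 40, in either order */
   return 0;
}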


@@ -441,14 +441,7 @@ BlockCycleEstimator::add(aco_ptr<Instruction>& instr)
mem_ops[i].push_back(cur_cycle + wait_info[i]);
}
/* This is inaccurate but shouldn't affect anything after waitcnt insertion.
* Before waitcnt insertion, this is necessary to consider memory operations.
*/
unsigned latency = 0;
for (unsigned i = 0; i < wait_type_num; i++)
latency = MAX2(latency, i == wait_type_vs ? 0 : wait_info[i]);
int32_t result_available = start + MAX2(perf.latency, (int32_t)latency);
int32_t result_available = start + perf.latency;
for (Definition& def : instr->definitions) {
int32_t* available = &reg_available[def.physReg().reg()];
for (unsigned i = 0; i < def.size(); i++)


@@ -520,6 +520,9 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_graphics_stat
NIR_PASS(_, stage->nir, ac_nir_lower_global_access);
NIR_PASS(_, stage->nir, nir_lower_int64);
if (gfx_level >= GFX11)
NIR_PASS(_, stage->nir, ac_nir_opt_flip_if_for_mem_loads);
radv_optimize_nir_algebraic(
stage->nir, io_to_mem || lowered_ngg || stage->stage == MESA_SHADER_COMPUTE || stage->stage == MESA_SHADER_TASK,
gfx_level >= GFX8, gfx_level);


@@ -1026,6 +1026,9 @@ static void run_late_optimization_and_lowering_passes(struct si_nir_shader_ctx *
};
NIR_PASS(_, nir, nir_opt_offsets, &offset_options);
if (sel->screen->info.gfx_level >= GFX11)
NIR_PASS(_, nir, ac_nir_opt_flip_if_for_mem_loads);
si_nir_late_opts(nir);
NIR_PASS(progress, nir, nir_opt_sink,