From 8a3cc7200ef7db396b1bfeeb83f8b67e303236ab Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Wed, 19 Nov 2025 11:53:07 +0000 Subject: [PATCH] aco/sched: don't use previously unused preserved registers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rhys Perry Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_scheduler.cpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/amd/compiler/aco_scheduler.cpp b/src/amd/compiler/aco_scheduler.cpp index 95d62f6b0f1..1feb9b1cc90 100644 --- a/src/amd/compiler/aco_scheduler.cpp +++ b/src/amd/compiler/aco_scheduler.cpp @@ -1311,8 +1311,10 @@ schedule_program(Program* program) RegisterDemand demand; for (Block& block : program->blocks) demand.update(block.register_demand); - demand.vgpr += program->config->num_shared_vgprs / 2; - demand.update(program->fixed_reg_demand); + + RegisterDemand usage = demand; + usage.vgpr += program->config->num_shared_vgprs / 2; + usage.update(program->fixed_reg_demand); sched_ctx ctx; ctx.gfx_level = program->gfx_level; @@ -1326,7 +1328,7 @@ schedule_program(Program* program) /* If we already have less waves than the minimum, don't reduce them further. * Otherwise, sacrifice some waves and use more VGPRs, in order to improve scheduling. */ - int vgpr_demand = std::max(24, demand.vgpr) + 12 * reg_file_multiple; + int vgpr_demand = std::max(24, usage.vgpr) + 12 * reg_file_multiple; int target_waves = std::max(wave_minimum, program->dev.physical_vgprs / vgpr_demand); target_waves = max_suitable_waves(program, std::min(program->num_waves, target_waves)); assert(target_waves >= program->min_waves); @@ -1334,6 +1336,14 @@ schedule_program(Program* program) ctx.mv.max_registers = get_addr_regs_from_waves(program, target_waves); ctx.mv.max_registers.vgpr -= 2; + /* If this is a callee, don't use unneeded preserved VGPRs. */ + if (program->is_callee) { + RegisterDemand limit = get_addr_regs_from_waves(program, program->min_waves); + RegisterDemand max_clobbered_regs = program->callee_abi.numClobbered(limit); + ctx.mv.max_registers.vgpr = std::min(ctx.mv.max_registers.vgpr, max_clobbered_regs.vgpr); + ctx.mv.max_registers.update(demand); + } + /* VMEM_MAX_MOVES and such assume pre-GFX10 wave count */ ctx.occupancy_factor = target_waves / wave_factor;