aco: unify get_addr_sgpr_from_waves() and get_addr_vgpr_from_waves() into one function

which returns the limit as RegisterDemand.

Also remove the unused get_extra_sgprs() from aco_ir.h.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33644>
This commit is contained in:
Daniel Schürmann 2024-07-24 13:12:28 +02:00 committed by Marge Bot
parent 6ea9443726
commit 52253da783
7 changed files with 26 additions and 39 deletions

View file

@ -12269,7 +12269,7 @@ get_arg_fixed(const struct ac_shader_args* args, struct ac_arg arg)
unsigned
load_vb_descs(Builder& bld, PhysReg dest, Operand base, unsigned start, unsigned max)
{
unsigned sgpr_limit = get_addr_sgpr_from_waves(bld.program, bld.program->min_waves);
unsigned sgpr_limit = get_addr_regs_from_waves(bld.program, bld.program->min_waves).sgpr;
unsigned count = MIN2((sgpr_limit - dest.reg()) / 4u, max);
for (unsigned i = 0; i < count;) {
unsigned size = 1u << util_logbase2(MIN2(count - i, 4));

View file

@ -2319,9 +2319,6 @@ int get_op_fixed_to_def(Instruction* instr);
RegisterDemand get_live_changes(Instruction* instr);
RegisterDemand get_temp_registers(Instruction* instr);
/* number of sgprs that need to be allocated but might notbe addressable as s0-s105 */
uint16_t get_extra_sgprs(Program* program);
/* adjust num_waves for workgroup size and LDS limits */
uint16_t max_suitable_waves(Program* program, uint16_t waves);
@ -2330,8 +2327,7 @@ uint16_t get_sgpr_alloc(Program* program, uint16_t addressable_sgprs);
uint16_t get_vgpr_alloc(Program* program, uint16_t addressable_vgprs);
/* return number of addressable sgprs/vgprs for max_waves */
uint16_t get_addr_sgpr_from_waves(Program* program, uint16_t max_waves);
uint16_t get_addr_vgpr_from_waves(Program* program, uint16_t max_waves);
RegisterDemand get_addr_regs_from_waves(Program* program, uint16_t waves);
bool uses_scratch(Program* program);

View file

@ -428,23 +428,19 @@ round_down(unsigned a, unsigned b)
return a - (a % b);
}
uint16_t
get_addr_sgpr_from_waves(Program* program, uint16_t waves)
RegisterDemand
get_addr_regs_from_waves(Program* program, uint16_t waves)
{
/* it's not possible to allocate more than 128 SGPRs */
uint16_t sgprs = std::min(program->dev.physical_sgprs / waves, 128);
sgprs = round_down(sgprs, program->dev.sgpr_alloc_granule);
sgprs -= get_extra_sgprs(program);
return std::min(sgprs, program->dev.sgpr_limit);
}
sgprs = round_down(sgprs, program->dev.sgpr_alloc_granule) - get_extra_sgprs(program);
sgprs = std::min(sgprs, program->dev.sgpr_limit);
uint16_t
get_addr_vgpr_from_waves(Program* program, uint16_t waves)
{
uint16_t vgprs = program->dev.physical_vgprs / waves;
vgprs = vgprs / program->dev.vgpr_alloc_granule * program->dev.vgpr_alloc_granule;
vgprs -= program->config->num_shared_vgprs / 2;
return std::min(vgprs, program->dev.vgpr_limit);
vgprs = std::min(vgprs, program->dev.vgpr_limit);
return RegisterDemand(vgprs, sgprs);
}
void
@ -496,11 +492,10 @@ void
update_vgpr_sgpr_demand(Program* program, const RegisterDemand new_demand)
{
assert(program->min_waves >= 1);
uint16_t sgpr_limit = get_addr_sgpr_from_waves(program, program->min_waves);
uint16_t vgpr_limit = get_addr_vgpr_from_waves(program, program->min_waves);
RegisterDemand limit = get_addr_regs_from_waves(program, program->min_waves);
/* this won't compile, register pressure reduction necessary */
if (new_demand.vgpr > vgpr_limit || new_demand.sgpr > sgpr_limit) {
if (new_demand.exceeds(limit)) {
program->num_waves = 0;
program->max_reg_demand = new_demand;
} else {
@ -513,8 +508,7 @@ update_vgpr_sgpr_demand(Program* program, const RegisterDemand new_demand)
/* Adjust for LDS and workgroup multiples and calculate max_reg_demand */
program->num_waves = max_suitable_waves(program, program->num_waves);
program->max_reg_demand.vgpr = get_addr_vgpr_from_waves(program, program->num_waves);
program->max_reg_demand.sgpr = get_addr_sgpr_from_waves(program, program->num_waves);
program->max_reg_demand = get_addr_regs_from_waves(program, program->num_waves);
}
}

View file

@ -112,8 +112,7 @@ struct ra_ctx {
aco_ptr<Instruction> phi_dummy;
uint16_t max_used_sgpr = 0;
uint16_t max_used_vgpr = 0;
uint16_t sgpr_limit;
uint16_t vgpr_limit;
RegisterDemand limit;
std::bitset<512> war_hint;
PhysRegIterator rr_sgpr_it;
PhysRegIterator rr_vgpr_it;
@ -131,8 +130,7 @@ struct ra_ctx {
{
pseudo_dummy.reset(create_instruction(aco_opcode::p_parallelcopy, Format::PSEUDO, 0, 0));
phi_dummy.reset(create_instruction(aco_opcode::p_linear_phi, Format::PSEUDO, 0, 0));
sgpr_limit = get_addr_sgpr_from_waves(program, program->min_waves);
vgpr_limit = get_addr_vgpr_from_waves(program, program->min_waves);
limit = get_addr_regs_from_waves(program, program->min_waves);
sgpr_bounds = program->max_reg_demand.sgpr;
vgpr_bounds = program->max_reg_demand.vgpr;
@ -797,7 +795,7 @@ add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg r
void
adjust_max_used_regs(ra_ctx& ctx, RegClass rc, unsigned reg)
{
uint16_t max_addressible_sgpr = ctx.sgpr_limit;
uint16_t max_addressible_sgpr = ctx.limit.sgpr;
unsigned size = rc.size();
if (rc.type() == RegType::vgpr) {
assert(reg >= 256);
@ -1421,13 +1419,13 @@ bool
increase_register_file(ra_ctx& ctx, RegClass rc)
{
if (rc.type() == RegType::vgpr && ctx.num_linear_vgprs == 0 &&
ctx.vgpr_bounds < ctx.vgpr_limit) {
ctx.vgpr_bounds < ctx.limit.vgpr) {
/* If vgpr_bounds is less than max_reg_demand.vgpr, this should be a no-op. */
update_vgpr_sgpr_demand(
ctx.program, RegisterDemand(ctx.vgpr_bounds + 1, ctx.program->max_reg_demand.sgpr));
ctx.vgpr_bounds = ctx.program->max_reg_demand.vgpr;
} else if (rc.type() == RegType::sgpr && ctx.program->max_reg_demand.sgpr < ctx.sgpr_limit) {
} else if (rc.type() == RegType::sgpr && ctx.program->max_reg_demand.sgpr < ctx.limit.sgpr) {
update_vgpr_sgpr_demand(
ctx.program, RegisterDemand(ctx.program->max_reg_demand.vgpr, ctx.sgpr_bounds + 1));

View file

@ -1260,8 +1260,8 @@ schedule_program(Program* program)
ctx.num_waves = max_suitable_waves(program, ctx.num_waves);
assert(ctx.num_waves >= program->min_waves);
ctx.mv.max_registers = {int16_t(get_addr_vgpr_from_waves(program, ctx.num_waves) - 2),
int16_t(get_addr_sgpr_from_waves(program, ctx.num_waves))};
ctx.mv.max_registers = get_addr_regs_from_waves(program, ctx.num_waves);
ctx.mv.max_registers.vgpr -= 2;
/* VMEM_MAX_MOVES and such assume pre-GFX10 wave count */
ctx.num_waves = std::max<uint16_t>(ctx.num_waves / wave_fac, 1);

View file

@ -1618,30 +1618,29 @@ spill(Program* program)
/* calculate target register demand */
const RegisterDemand demand = program->max_reg_demand; /* current max */
const uint16_t sgpr_limit = get_addr_sgpr_from_waves(program, program->min_waves);
const uint16_t vgpr_limit = get_addr_vgpr_from_waves(program, program->min_waves);
const RegisterDemand limit = get_addr_regs_from_waves(program, program->min_waves);
uint16_t extra_vgprs = 0;
uint16_t extra_sgprs = 0;
/* calculate extra VGPRs required for spilling SGPRs */
if (demand.sgpr > sgpr_limit) {
unsigned sgpr_spills = demand.sgpr - sgpr_limit;
if (demand.sgpr > limit.sgpr) {
unsigned sgpr_spills = demand.sgpr - limit.sgpr;
extra_vgprs = DIV_ROUND_UP(sgpr_spills * 2, program->wave_size) + 1;
}
/* add extra SGPRs required for spilling VGPRs */
if (demand.vgpr + extra_vgprs > vgpr_limit) {
if (demand.vgpr + extra_vgprs > limit.vgpr) {
if (program->gfx_level >= GFX9)
extra_sgprs = 1; /* SADDR */
else
extra_sgprs = 5; /* scratch_resource (s4) + scratch_offset (s1) */
if (demand.sgpr + extra_sgprs > sgpr_limit) {
if (demand.sgpr + extra_sgprs > limit.sgpr) {
/* re-calculate in case something has changed */
unsigned sgpr_spills = demand.sgpr + extra_sgprs - sgpr_limit;
unsigned sgpr_spills = demand.sgpr + extra_sgprs - limit.sgpr;
extra_vgprs = DIV_ROUND_UP(sgpr_spills * 2, program->wave_size) + 1;
}
}
/* the spiller has to target the following register demand */
const RegisterDemand target(vgpr_limit - extra_vgprs, sgpr_limit - extra_sgprs);
const RegisterDemand target(limit.vgpr - extra_vgprs, limit.sgpr - extra_sgprs);
/* initialize ctx */
spill_ctx ctx(target, program);

View file

@ -1391,7 +1391,7 @@ validate_ra(Program* program)
bool err = false;
aco::live_var_analysis(program);
std::vector<std::vector<Temp>> phi_sgpr_ops(program->blocks.size());
uint16_t sgpr_limit = get_addr_sgpr_from_waves(program, program->num_waves);
uint16_t sgpr_limit = get_addr_regs_from_waves(program, program->num_waves).sgpr;
std::vector<Assignment> assignments(program->peekAllocationId());
for (Block& block : program->blocks) {