diff --git a/src/amd/compiler/aco_scheduler.cpp b/src/amd/compiler/aco_scheduler.cpp index 67e24bcfbb4..e11bb8d39cd 100644 --- a/src/amd/compiler/aco_scheduler.cpp +++ b/src/amd/compiler/aco_scheduler.cpp @@ -5,6 +5,7 @@ */ #include "aco_ir.h" +#include "aco_util.h" #include "common/amdgfxregs.h" @@ -80,6 +81,7 @@ struct UpwardsCursor { }; struct MoveState { + monotonic_buffer_resource m; RegisterDemand max_registers; Block* block; @@ -87,7 +89,8 @@ struct MoveState { bool improved_rar; std::vector depends_on; - std::vector RAR_dependencies; + aco::unordered_map rar_dependencies; /* temp-id -> index relative to insert_idx */ + MoveState() : rar_dependencies(m) {} /* for moving instructions before the current instruction to after it */ DownwardsCursor downwards_init(int current_idx, bool improved_rar); @@ -162,13 +165,13 @@ MoveState::downwards_init(int current_idx, bool improved_rar_) std::fill(depends_on.begin(), depends_on.end(), false); if (improved_rar) - std::fill(RAR_dependencies.begin(), RAR_dependencies.end(), false); + rar_dependencies.clear(); for (const Operand& op : current->operands) { if (op.isTemp()) { depends_on[op.tempId()] = true; if (improved_rar && op.isFirstKill()) - RAR_dependencies[op.tempId()] = true; + rar_dependencies[op.tempId()] = -1; } } @@ -181,16 +184,19 @@ MoveState::downwards_init(int current_idx, bool improved_rar_) } bool -check_dependencies(Instruction* instr, std::vector& def_dep, std::vector& op_dep) +check_dependencies(Instruction* instr, std::vector& def_dep, + aco::unordered_map& rar_deps, bool improved_rar) { for (const Definition& def : instr->definitions) { if (def.isTemp() && def_dep[def.tempId()]) return true; } for (const Operand& op : instr->operands) { - if (op.isTemp() && op_dep[op.tempId()]) { - // FIXME: account for difference in register pressure - return true; + if (op.isTemp()) { + if ((improved_rar && rar_deps.count(op.tempId())) || + (!improved_rar && def_dep[op.tempId()])) + // FIXME: account for difference in register pressure + return true; } } return false; @@ -203,8 +209,7 @@ MoveState::downwards_move(DownwardsCursor& cursor) aco_ptr& candidate = block->instructions[cursor.source_idx]; /* check if one of candidate's operands is killed by depending instruction */ - std::vector& RAR_deps = improved_rar ? RAR_dependencies : depends_on; - if (check_dependencies(candidate.get(), depends_on, RAR_deps)) + if (check_dependencies(candidate.get(), depends_on, rar_dependencies, improved_rar)) return move_fail_ssa; /* Check the new demand of the instructions being moved over: @@ -265,10 +270,10 @@ MoveState::downwards_move_clause(DownwardsCursor& cursor) int insert_idx = cursor.insert_idx_clause - 1; Instruction* instr = block->instructions[cursor.insert_idx_clause].get(); - /* Remove instruction operands from RAR_dependencies as the clause won't be moved further. */ + /* Remove instruction operands from rar_dependencies as the clause won't be moved further. */ for (const Operand& op : current->operands) { if (op.isTemp() && op.isFirstKill()) - RAR_dependencies[op.tempId()] = false; + rar_dependencies.erase(op.tempId()); } /* Check if one of candidates' operands is killed by depending instruction. */ @@ -276,7 +281,7 @@ MoveState::downwards_move_clause(DownwardsCursor& cursor) while (should_form_clause(block->instructions[clause_begin_idx].get(), instr)) { Instruction* candidate = block->instructions[clause_begin_idx--].get(); - if (check_dependencies(candidate, depends_on, RAR_dependencies)) + if (check_dependencies(candidate, depends_on, rar_dependencies, true)) return move_fail_ssa; max_clause_demand.update(candidate->register_demand); @@ -330,7 +335,7 @@ MoveState::downwards_skip(DownwardsCursor& cursor) if (op.isTemp()) { depends_on[op.tempId()] = true; if (improved_rar && op.isFirstKill()) - RAR_dependencies[op.tempId()] = true; + rar_dependencies[op.tempId()] = cursor.source_idx - cursor.insert_idx; } } cursor.total_demand.update(instr->register_demand); @@ -362,7 +367,7 @@ MoveState::upwards_init(int source_idx, bool improved_rar_) improved_rar = improved_rar_; std::fill(depends_on.begin(), depends_on.end(), false); - std::fill(RAR_dependencies.begin(), RAR_dependencies.end(), false); + rar_dependencies.clear(); for (const Definition& def : current->definitions) { if (def.isTemp()) @@ -405,7 +410,7 @@ MoveState::upwards_move(UpwardsCursor& cursor) /* check if candidate uses/kills an operand which is used by a dependency */ for (const Operand& op : instr->operands) { - if (op.isTemp() && (!improved_rar || op.isFirstKill()) && RAR_dependencies[op.tempId()]) + if (op.isTemp() && (!improved_rar || op.isFirstKill()) && rar_dependencies.count(op.tempId())) return move_fail_rar; } @@ -448,7 +453,7 @@ MoveState::upwards_skip(UpwardsCursor& cursor) } for (const Operand& op : instr->operands) { if (op.isTemp()) - RAR_dependencies[op.tempId()] = true; + rar_dependencies[op.tempId()] = cursor.source_idx - cursor.insert_idx; } cursor.total_demand.update(instr->register_demand); } @@ -1257,7 +1262,6 @@ schedule_program(Program* program) ctx.gfx_level = program->gfx_level; ctx.program = program; ctx.mv.depends_on.resize(program->peekAllocationId()); - ctx.mv.RAR_dependencies.resize(program->peekAllocationId()); const int wave_factor = program->gfx_level >= GFX10 ? 2 : 1; const int wave_minimum = std::max(program->min_waves, 4 * wave_factor);