diff --git a/src/amd/compiler/aco_lower_to_cssa.cpp b/src/amd/compiler/aco_lower_to_cssa.cpp index 74f6b9ac67f..3725c72881a 100644 --- a/src/amd/compiler/aco_lower_to_cssa.cpp +++ b/src/amd/compiler/aco_lower_to_cssa.cpp @@ -48,7 +48,6 @@ struct merge_node { struct cssa_ctx { Program* program; - std::vector& live_out; /* live-out sets per block */ std::vector> parallelcopies; /* copies per block */ std::vector merge_sets; /* each vector is one (ordered) merge set */ std::unordered_map merge_node_table; /* tempid -> merge node */ @@ -108,11 +107,6 @@ collect_parallelcopies(cssa_ctx& ctx) set.emplace_back(tmp); ctx.merge_node_table[tmp.id()] = {op, index, preds[i]}; - /* update the liveness information */ - if (op.isKill()) - ctx.live_out[preds[i]].erase(op.tempId()); - ctx.live_out[preds[i]].insert(tmp.id()); - has_preheader_copy |= i == 0 && block.kind & block_kind_loop_header; } @@ -171,23 +165,24 @@ intersects(cssa_ctx& ctx, Temp var, Temp parent) assert(node_var.index != node_parent.index); uint32_t block_idx = node_var.defined_at; - /* if the parent is live-out at the definition block of var, they intersect */ - bool parent_live = ctx.live_out[block_idx].count(parent.id()); - if (parent_live) - return true; - - /* parent is defined in a different block than var */ + /* if parent is defined in a different block than var */ if (node_parent.defined_at < node_var.defined_at) { /* if the parent is not live-in, they don't interfere */ - Block::edge_vec& preds = var.type() == RegType::vgpr - ? ctx.program->blocks[block_idx].logical_preds - : ctx.program->blocks[block_idx].linear_preds; - for (uint32_t pred : preds) { - if (!ctx.live_out[pred].count(parent.id())) - return false; - } + if (!ctx.program->live.live_in[block_idx].count(parent.id())) + return false; } + /* if the parent is live-out at the definition block of var, they intersect */ + Block::edge_vec& succs = var.type() == RegType::vgpr + ? ctx.program->blocks[block_idx].logical_succs + : ctx.program->blocks[block_idx].linear_succs; + + bool parent_live = std::any_of(succs.begin(), succs.end(), + [&](unsigned succ) + { return ctx.program->live.live_in[succ].count(parent.id()); }); + if (parent_live) + return true; + for (const copy& cp : ctx.parallelcopies[block_idx]) { /* if var is defined at the edge, they don't intersect */ if (cp.def.getTemp() == var) @@ -334,12 +329,10 @@ try_coalesce_copy(cssa_ctx& ctx, copy copy, uint32_t block_idx) merge_node& op_node = ctx.merge_node_table[copy.op.tempId()]; if (op_node.defined_at == -1u) { /* find defining block of operand */ - uint32_t pred = block_idx; - do { - block_idx = pred; - pred = copy.op.regClass().type() == RegType::vgpr ? ctx.program->blocks[pred].logical_idom - : ctx.program->blocks[pred].linear_idom; - } while (block_idx != pred && ctx.live_out[pred].count(copy.op.tempId())); + while (ctx.program->live.live_in[block_idx].count(copy.op.tempId())) + block_idx = copy.op.regClass().type() == RegType::vgpr + ? ctx.program->blocks[block_idx].logical_idom + : ctx.program->blocks[block_idx].linear_idom; op_node.defined_at = block_idx; op_node.value = copy.op; } @@ -441,9 +434,6 @@ emit_parallelcopies(cssa_ctx& ctx) for (const copy& cp : ctx.parallelcopies[i]) { if (try_coalesce_copy(ctx, cp, i)) { renames.emplace(cp.def.tempId(), cp.op); - /* update liveness info */ - ctx.live_out[i].erase(cp.def.tempId()); - ctx.live_out[i].insert(cp.op.tempId()); } else { uint32_t read_idx = -1u; if (cp.op.isTemp()) @@ -517,7 +507,7 @@ void lower_to_cssa(Program* program) { reindex_ssa(program, true); - cssa_ctx ctx = {program, program->live.live_out}; + cssa_ctx ctx = {program}; collect_parallelcopies(ctx); emit_parallelcopies(ctx); diff --git a/src/amd/compiler/aco_reindex_ssa.cpp b/src/amd/compiler/aco_reindex_ssa.cpp index 44912fead1c..5ab58fb92a9 100644 --- a/src/amd/compiler/aco_reindex_ssa.cpp +++ b/src/amd/compiler/aco_reindex_ssa.cpp @@ -90,6 +90,12 @@ reindex_ssa(Program* program, bool update_live_out = false) new_set.insert(ctx.renames[id]); set = new_set; } + for (IDSet& set : program->live.live_in) { + IDSet new_set(program->live.memory); + for (uint32_t id : set) + new_set.insert(ctx.renames[id]); + set = new_set; + } } program->allocationID = program->temp_rc.size();