diff --git a/src/amd/compiler/aco_lower_branches.cpp b/src/amd/compiler/aco_lower_branches.cpp index bf124bdd6fe..a9324b0715c 100644 --- a/src/amd/compiler/aco_lower_branches.cpp +++ b/src/amd/compiler/aco_lower_branches.cpp @@ -38,6 +38,102 @@ remove_linear_successor(branch_ctx& ctx, Block& block, uint32_t succ_index) } } +void +try_remove_simple_block(branch_ctx& ctx, Block& block) +{ + if (!block.instructions.empty() && block.instructions.front()->opcode != aco_opcode::s_branch) + return; + + /* Don't remove the preheader as it might be needed as convergence point + * in order to insert code (e.g. for loop alignment, wait states, etc.). + */ + if (block.kind & block_kind_loop_preheader) + return; + + unsigned succ_idx = block.linear_succs[0]; + Block& succ = ctx.program->blocks[succ_idx]; + for (unsigned pred_idx : block.linear_preds) { + Block& pred = ctx.program->blocks[pred_idx]; + assert(pred.index < block.index); + assert(!pred.instructions.empty() && pred.instructions.back()->isBranch()); + Instruction* branch = pred.instructions.back().get(); + if (branch->opcode == aco_opcode::p_branch) { + /* The predecessor unconditionally jumps to this block. Redirect to successor. */ + pred.linear_succs[0] = succ_idx; + succ.linear_preds.push_back(pred_idx); + } else if (pred.linear_succs[0] == succ_idx || pred.linear_succs[1] == succ_idx) { + /* The predecessor's alternative target is this block's successor. */ + pred.linear_succs[0] = succ_idx; + pred.linear_succs[1] = pred.linear_succs.back(); /* In case of discard */ + pred.linear_succs.pop_back(); + branch->opcode = aco_opcode::p_branch; + } else if (pred.linear_succs[1] == block.index) { + /* The predecessor jumps to this block. Redirect to successor. */ + pred.linear_succs[1] = succ_idx; + succ.linear_preds.push_back(pred_idx); + } else { + /* This block is the fall-through target of the predecessor. */ + if (block.instructions.empty()) { + /* If this block is empty, just fall-through to the successor. */ + pred.linear_succs[0] = succ_idx; + succ.linear_preds.push_back(pred_idx); + continue; + } + + /* Otherwise, check if there is a fall-through path for the jump target. */ + if (block.index >= pred.linear_succs[1]) + return; + for (unsigned j = block.index + 1; j < pred.linear_succs[1]; j++) { + if (!ctx.program->blocks[j].instructions.empty()) + return; + } + pred.linear_succs[0] = pred.linear_succs[1]; + pred.linear_succs[1] = succ_idx; + succ.linear_preds.push_back(pred_idx); + + /* Invert the condition. This branch now falls through to its original target. + * However, we don't update the fall-through target since this instruction + * gets lowered in the next step, anyway. + */ + if (branch->opcode == aco_opcode::p_cbranch_nz) + branch->opcode = aco_opcode::p_cbranch_z; + else + branch->opcode = aco_opcode::p_cbranch_nz; + } + + /* Update the branch target. */ + branch->branch().target[0] = succ_idx; + } + + /* If this block is part of the logical CFG, also connect pre- and successors. */ + if (!block.logical_succs.empty()) { + assert(block.logical_succs.size() == 1); + unsigned logical_succ_idx = block.logical_succs[0]; + Block& logical_succ = ctx.program->blocks[logical_succ_idx]; + ASSERTED auto it = std::remove(logical_succ.logical_preds.begin(), + logical_succ.logical_preds.end(), block.index); + assert(std::next(it) == logical_succ.logical_preds.end()); + logical_succ.logical_preds.pop_back(); + for (unsigned pred_idx : block.logical_preds) { + Block& pred = ctx.program->blocks[pred_idx]; + std::replace(pred.logical_succs.begin(), pred.logical_succs.end(), block.index, + logical_succ_idx); + + if (pred.logical_succs.size() == 2 && pred.logical_succs[0] == pred.logical_succs[1]) + pred.logical_succs.pop_back(); /* This should have been optimized in NIR! */ + else + logical_succ.logical_preds.push_back(pred_idx); + } + + block.logical_succs.clear(); + block.logical_preds.clear(); + } + + remove_linear_successor(ctx, block, succ_idx); + block.linear_preds.clear(); + block.instructions.clear(); +} + void eliminate_useless_exec_writes_in_block(branch_ctx& ctx, Block& block) { @@ -115,7 +211,9 @@ can_remove_branch(branch_ctx& ctx, Block& block, Pseudo_branch_instruction* bran branch->operands[0].physReg() == exec); if (branch->never_taken) { - assert(!uniform_branch); + assert(!uniform_branch || std::all_of(std::next(ctx.program->blocks.begin(), block.index + 1), + std::next(ctx.program->blocks.begin(), target), + [](Block& b) { return b.instructions.empty(); })); return true; } @@ -259,6 +357,9 @@ lower_branches(Program* program) Block& block = program->blocks[i]; lower_branch_instruction(ctx, block); eliminate_useless_exec_writes_in_block(ctx, block); + + if (block.linear_succs.size() == 1) + try_remove_simple_block(ctx, block); } }