From 64fda091de9132b3694af58f2e9a3427d3419a40 Mon Sep 17 00:00:00 2001 From: Friedrich Vock Date: Thu, 8 Jun 2023 18:56:43 +0200 Subject: [PATCH] aco: Lower divergent bool phis iteratively MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Avoids stack overflows with really large programs. No fossil-db changes. Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/8760 Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/8701 Reviewed-by: Daniel Schürmann Cc: mesa-stable Part-of: --- src/amd/compiler/aco_lower_phis.cpp | 178 ++++++++++++++-------------- 1 file changed, 90 insertions(+), 88 deletions(-) diff --git a/src/amd/compiler/aco_lower_phis.cpp b/src/amd/compiler/aco_lower_phis.cpp index b9b44e8fb86..bd803c68b5f 100644 --- a/src/amd/compiler/aco_lower_phis.cpp +++ b/src/amd/compiler/aco_lower_phis.cpp @@ -52,73 +52,72 @@ struct ssa_state { std::vector outputs; /* the output per block */ }; -Operand -get_ssa(Program* program, unsigned block_idx, ssa_state* state, bool input) +Operand get_output(Program* program, unsigned block_idx, ssa_state* state); + +void +init_outputs(Program* program, ssa_state* state, unsigned start, unsigned end) { - if (!input) { - if (state->visited[block_idx]) - return state->outputs[block_idx]; - - /* otherwise, output == input */ - Operand output = get_ssa(program, block_idx, state, true); - state->visited[block_idx] = true; - state->outputs[block_idx] = output; - return output; + for (unsigned i = start; i < end; ++i) { + if (state->visited[i]) + continue; + state->outputs[i] = get_output(program, i, state); + state->visited[i] = true; } +} + +Operand +get_output(Program* program, unsigned block_idx, ssa_state* state) +{ + Block& block = program->blocks[block_idx]; - /* retrieve the Operand by checking the predecessors */ if (state->any_pred_defined[block_idx] == pred_defined::undef) return Operand(program->lane_mask); - Block& block = program->blocks[block_idx]; - size_t pred = block.linear_preds.size(); - Operand op; - if (block.loop_nest_depth < state->loop_nest_depth) { + if (block.loop_nest_depth < state->loop_nest_depth) /* loop-carried value for loop exit phis */ - op = Operand::zero(program->lane_mask.bytes()); - } else if (block.loop_nest_depth > state->loop_nest_depth || pred == 1 || - block.kind & block_kind_loop_exit) { - op = get_ssa(program, block.linear_preds[0], state, false); + return Operand::zero(program->lane_mask.bytes()); + + size_t num_preds = block.linear_preds.size(); + + if (block.loop_nest_depth > state->loop_nest_depth || num_preds == 1 || + block.kind & block_kind_loop_exit) + return state->outputs[block.linear_preds[0]]; + + Operand output; + + /* Loop headers can contain back edges, in which case the predecessor + * outputs aren't yet determined because the predecessor is after the block. + * The predecessor outputs also depend on the output of the loop header, + * so allocate a temporary that will store this block's output and use that + * to calculate the predecessor block output. In this case, we always emit a phi + * to ensure the allocated temporary is defined. */ + if (block.kind & block_kind_loop_header) { + unsigned start_idx = block_idx + 1; + unsigned end_idx = block.linear_preds.back() + 1; + + state->outputs[block_idx] = Operand(Temp(program->allocateTmp(program->lane_mask))); + init_outputs(program, state, start_idx, end_idx); + output = state->outputs[block_idx]; + } else if (std::all_of(block.linear_preds.begin() + 1, block.linear_preds.end(), + [&](unsigned pred) { + return state->outputs[pred] == state->outputs[block.linear_preds[0]]; + })) { + return state->outputs[block.linear_preds[0]]; } else { - assert(pred > 1); - bool previously_visited = state->visited[block_idx]; - /* potential recursion: anchor at loop header */ - if (block.kind & block_kind_loop_header) { - assert(!previously_visited); - previously_visited = true; - state->visited[block_idx] = true; - state->outputs[block_idx] = Operand(Temp(program->allocateTmp(program->lane_mask))); - } - - /* collect predecessor output operands */ - std::vector ops(pred); - for (unsigned i = 0; i < pred; i++) - ops[i] = get_ssa(program, block.linear_preds[i], state, false); - - /* check triviality */ - if (std::all_of(ops.begin() + 1, ops.end(), [&](Operand same) { return same == ops[0]; })) - return ops[0]; - - /* Return if this was handled in a recursive call by a loop header phi */ - if (!previously_visited && state->visited[block_idx]) - return state->outputs[block_idx]; - - if (block.kind & block_kind_loop_header) - op = state->outputs[block_idx]; - else - op = Operand(Temp(program->allocateTmp(program->lane_mask))); - - /* create phi */ - aco_ptr phi{ - create_instruction(aco_opcode::p_linear_phi, Format::PSEUDO, pred, 1)}; - for (unsigned i = 0; i < pred; i++) - phi->operands[i] = ops[i]; - phi->definitions[0] = Definition(op.getTemp()); - block.instructions.emplace(block.instructions.begin(), std::move(phi)); + output = Operand(Temp(program->allocateTmp(program->lane_mask))); } - assert(op.size() == program->lane_mask.size()); - return op; + /* create phi */ + aco_ptr phi{create_instruction( + aco_opcode::p_linear_phi, Format::PSEUDO, num_preds, 1)}; + for (unsigned i = 0; i < num_preds; i++) + phi->operands[i] = state->outputs[block.linear_preds[i]]; + phi->definitions[0] = Definition(output.getTemp()); + block.instructions.emplace(block.instructions.begin(), std::move(phi)); + + assert(output.size() == program->lane_mask.size()); + + return output; } void @@ -141,7 +140,7 @@ build_merge_code(Program* program, ssa_state* state, Block* block, Operand cur) { unsigned block_idx = block->index; Definition dst = Definition(state->outputs[block_idx].getTemp()); - Operand prev = get_ssa(program, block_idx, state, true); + Operand prev = get_output(program, block_idx, state); if (cur.isUndefined()) cur = Operand::zero(program->lane_mask.bytes()); @@ -239,9 +238,20 @@ build_const_else_merge_code(Program* program, Block& invert_block, aco_ptr& phi) +init_state(Program* program, Block* block, ssa_state* state, aco_ptr& phi) { + Builder bld(program); + + /* do this here to avoid resizing in case of no boolean phis */ + state->visited.resize(program->blocks.size()); + state->outputs.resize(program->blocks.size()); + state->any_pred_defined.resize(program->blocks.size()); + state->loop_nest_depth = block->loop_nest_depth; + if (block->kind & block_kind_loop_exit) + state->loop_nest_depth += 1; + std::fill(state->visited.begin(), state->visited.end(), false); std::fill(state->any_pred_defined.begin(), state->any_pred_defined.end(), pred_defined::undef); + for (unsigned i = 0; i < block->logical_preds.size(); i++) { if (phi->operands[i].isUndefined()) continue; @@ -255,14 +265,14 @@ init_any_pred_defined(Program* program, ssa_state* state, Block* block, aco_ptr< unsigned start = block->logical_preds[0]; unsigned end = block->index; - /* for loop exit phis, start at the loop header */ + /* for loop exit phis, start at the loop pre-header */ if (block->kind & block_kind_loop_exit) { - while (program->blocks[start - 1].loop_nest_depth >= state->loop_nest_depth) + while (program->blocks[start].loop_nest_depth >= state->loop_nest_depth) start--; /* If the loop-header has a back-edge, we need to insert a phi. * This will contain a defined value */ - if (program->blocks[start].linear_preds.size() > 1) - state->any_pred_defined[start] = pred_defined::temp; + if (program->blocks[start + 1].linear_preds.size() > 1) + state->any_pred_defined[start + 1] = pred_defined::temp; } /* for loop header phis, end at the loop exit */ if (block->kind & block_kind_loop_header) { @@ -277,10 +287,10 @@ init_any_pred_defined(Program* program, ssa_state* state, Block* block, aco_ptr< // TODO: find more occasions where pred_defined::zero is beneficial (e.g. with 2+ temp merges) if (block->kind & block_kind_loop_exit) { /* zero the loop-carried variable */ - if (program->blocks[start].linear_preds.size() > 1) { - state->any_pred_defined[start] |= pred_defined::zero; + if (program->blocks[start + 1].linear_preds.size() > 1) { + state->any_pred_defined[start + 1] |= pred_defined::zero; // TODO: emit this zero explicitly - state->any_pred_defined[start - 1] = pred_defined::const_0; + state->any_pred_defined[start] = pred_defined::const_0; } } @@ -292,14 +302,24 @@ init_any_pred_defined(Program* program, ssa_state* state, Block* block, aco_ptr< } state->any_pred_defined[block->index] = pred_defined::undef; + + for (unsigned i = 0; i < phi->operands.size(); i++) { + unsigned pred = block->logical_preds[i]; + if (state->any_pred_defined[pred] != pred_defined::undef) + state->outputs[pred] = Operand(bld.tmp(bld.lm)); + else + state->outputs[pred] = phi->operands[i]; + assert(state->outputs[pred].size() == bld.lm.size()); + state->visited[pred] = true; + } + + init_outputs(program, state, start, end); } void lower_divergent_bool_phi(Program* program, ssa_state* state, Block* block, aco_ptr& phi) { - Builder bld(program); - if (!state->checked_preds_for_uniform) { state->all_preds_uniform = !(block->kind & block_kind_merge) && block->linear_preds.size() == block->logical_preds.size(); @@ -320,25 +340,7 @@ lower_divergent_bool_phi(Program* program, ssa_state* state, Block* block, return; } - /* do this here to avoid resizing in case of no boolean phis */ - state->visited.resize(program->blocks.size()); - state->outputs.resize(program->blocks.size()); - state->any_pred_defined.resize(program->blocks.size()); - state->loop_nest_depth = block->loop_nest_depth; - if (block->kind & block_kind_loop_exit) - state->loop_nest_depth += 1; - std::fill(state->visited.begin(), state->visited.end(), false); - init_any_pred_defined(program, state, block, phi); - - for (unsigned i = 0; i < phi->operands.size(); i++) { - unsigned pred = block->logical_preds[i]; - if (state->any_pred_defined[pred] != pred_defined::undef) - state->outputs[pred] = Operand(bld.tmp(bld.lm)); - else - state->outputs[pred] = phi->operands[i]; - assert(state->outputs[pred].size() == bld.lm.size()); - state->visited[pred] = true; - } + init_state(program, block, state, phi); for (unsigned i = 0; i < phi->operands.size(); i++) build_merge_code(program, state, &program->blocks[block->logical_preds[i]], phi->operands[i]); @@ -355,7 +357,7 @@ lower_divergent_bool_phi(Program* program, ssa_state* state, Block* block, assert(phi->operands.size() == num_preds); for (unsigned i = 0; i < num_preds; i++) - phi->operands[i] = get_ssa(program, block->linear_preds[i], state, false); + phi->operands[i] = state->outputs[block->linear_preds[i]]; return; }