diff --git a/src/compiler/nir/nir_lower_continue_constructs.c b/src/compiler/nir/nir_lower_continue_constructs.c index ce0a74ea1ce..bb0e8ff588d 100644 --- a/src/compiler/nir/nir_lower_continue_constructs.c +++ b/src/compiler/nir/nir_lower_continue_constructs.c @@ -26,12 +26,217 @@ #include "nir_builder.h" #include "nir_control_flow.h" +/* NIR pass to lower loop continue constructs. + * + * NIR loops are maintained in canonical form with these properties: + * - a pre-header: the only predecessor of the loop header + * - a dedicated exit node: dominated by the loop-header + * - a single back-edge to the loop header: the trivial continue + * + * If the loop has a continue construct, the trivial continue is the + * back-edge from the last block of the continue construct to the loop + * header. Otherwise, it is the back-edge from the last block of the + * loop body to the loop header. + * + * In order to lower the continue construct of a loop, all continue + * statements are being removed by either + * - moving the following code to the other side of a branch or + * - guarding following code by inserted IF-statements + * + * Afterwards, the continue construct is inlined before the trivial + * back-edge. + * + */ + +struct loop_simplify_state { + nir_builder *b; + nir_def *continue_flag; + struct exec_list *cf_list; +}; + +static bool +block_ends_in_continue(nir_block *block) +{ + if (nir_block_ends_in_jump(block)) { + nir_jump_instr *jump = nir_instr_as_jump(nir_block_last_instr(block)); + return jump->type == nir_jump_continue; + } + + return false; +} + +static bool +lower_continues_in_cf_list(struct exec_list *cf_list, + struct loop_simplify_state *state); + +static bool +lower_continue(nir_block *block, struct loop_simplify_state *state) +{ + if (!block_ends_in_continue(block)) + return false; + + assert(nir_cf_node_is_last(&block->cf_node)); + + /* Remove the continue instruction and set the predicate to 'true'. */ + state->b->cursor = nir_instr_remove(nir_block_last_instr(block)); + nir_store_reg(state->b, nir_imm_true(state->b), state->continue_flag); + + return true; +} + +static bool +lower_continues_in_if(nir_if *nif, struct loop_simplify_state *state) +{ + nir_block *then_block = nir_if_last_then_block(nif); + nir_block *else_block = nir_if_last_else_block(nif); + bool then_jumps = nir_block_ends_in_jump(then_block); + bool else_jumps = nir_block_ends_in_jump(else_block); + + bool progress = false; + progress |= lower_continue(then_block, state); + progress |= lower_continue(else_block, state); + + nir_block *next_block = nir_cf_node_cf_tree_next(&nif->cf_node); + bool is_empty_block = nir_cf_node_is_last(&next_block->cf_node) && + exec_list_is_empty(&next_block->instr_list); + + /* If a branch leg ends in a jump, we lower already any continue statements, + * so that we know if we have to move the following blocks to the other side. + */ + if (then_jumps) + progress |= lower_continues_in_cf_list(&nif->then_list, state); + if (else_jumps) + progress |= lower_continues_in_cf_list(&nif->else_list, state); + + if (!is_empty_block && progress) { + /* If one side ends in a jump and has a continue statement, move the + * following code to the other side. This is necessary to maintain SSA + * dominance, because the jump gets guarded or removed. + * Doing this based on (then_jumps || else_jumps) would also be a valid + * transform. However, it's unnecessary if we don't have any continues + * to lower and increases control-flow depth. + */ + assert(then_jumps || else_jumps); + /* Lower away potential single-src phis, if there was any. */ + nir_lower_phis_to_regs_block(nir_cf_node_cf_tree_next(&nif->cf_node), true); + nir_cf_list list; + nir_cf_extract(&list, nir_after_cf_node(&nif->cf_node), + nir_after_cf_list(state->cf_list)); + + if (then_jumps && else_jumps) { + /* Both branches jump, just delete instructions following the IF. */ + nir_cf_delete(&list); + } else if (then_jumps) { + nir_cf_reinsert(&list, nir_after_cf_list(&nif->else_list)); + } else { + nir_cf_reinsert(&list, nir_after_cf_list(&nif->then_list)); + } + + /* The successor is now empty. No need to predicate following blocks. */ + is_empty_block = true; + } + + /* Recursively lower any continue statements for the sides that didn't jump. + * We do that here so that this also lowers the code which was after the IF + * which we just moved inside. + */ + if (!then_jumps) + progress |= lower_continues_in_cf_list(&nif->then_list, state); + if (!else_jumps) + progress |= lower_continues_in_cf_list(&nif->else_list, state); + + if (!is_empty_block && progress) { + /* Predicate following blocks. */ + nir_cf_list list; + nir_cf_extract(&list, nir_after_cf_node_and_phis(&nif->cf_node), + nir_after_cf_list(state->cf_list)); + + state->b->cursor = nir_after_cf_node_and_phis(&nif->cf_node); + nir_if *if_stmt = nir_push_if(state->b, nir_load_reg(state->b, state->continue_flag)); + + assert(!exec_list_is_empty(&list.list)); + nir_cf_reinsert(&list, nir_before_cf_list(&if_stmt->else_list)); + nir_pop_if(state->b, NULL); + } + + return progress; +} + +static bool +lower_continues_in_cf_list(struct exec_list *cf_list, + struct loop_simplify_state *state) +{ + bool progress = false; + + struct exec_list *parent_list = state->cf_list; + state->cf_list = cf_list; + + /* We iterate over the list backwards because any given lower call may + * take everything following the given CF node and predicate it. In + * order to avoid recursion/iteration problems, we want everything after + * a given node to already be lowered before this happens. + */ + foreach_list_typed_reverse_safe(nir_cf_node, node, node, cf_list) { + switch (node->type) { + case nir_cf_node_if: + if (lower_continues_in_if(nir_cf_node_as_if(node), state)) + progress = true; + break; + + case nir_cf_node_block: + case nir_cf_node_loop: + break; + + default: + UNREACHABLE("Invalid inner CF node type"); + } + } + + state->cf_list = parent_list; + + return progress; +} + +static void +simplify_loop(nir_loop *loop) +{ + nir_block *cont = nir_loop_first_continue_block(loop); + nir_block *last = nir_loop_last_block(loop); + + /* Remove trivial continue statement. */ + if (block_ends_in_continue(last)) + nir_instr_remove(nir_block_last_instr(last)); + + /* If the loop has only the trivial continue, there is nothing to do. */ + if (!nir_block_ends_in_jump(last) && cont->predecessors.entries == 1) + return; + + struct loop_simplify_state state; + nir_builder b = nir_builder_at(nir_before_block_after_phis(nir_loop_first_block(loop))); + state.b = &b; + + /* Initialize the variable to False. */ + state.continue_flag = nir_decl_reg(&b, 1, 1, 0); + nir_store_reg(&b, nir_imm_false(&b), state.continue_flag); + + lower_continues_in_cf_list(&loop->body, &state); + + return; +} + static bool lower_loop_continue_block(nir_builder *b, nir_loop *loop, bool *repair_ssa) { if (!nir_loop_has_continue_construct(loop)) return false; + /* Lower loop header and continue-phis to regs as we are going to move the predecessors. */ + nir_lower_phis_to_regs_block(nir_loop_first_block(loop), true); + nir_lower_phis_to_regs_block(nir_loop_first_continue_block(loop), true); + + /* Simplify the loop in order to ensure that it has at most one back-edge. */ + simplify_loop(loop); + nir_block *header = nir_loop_first_block(loop); nir_block *cont = nir_loop_first_continue_block(loop); @@ -49,8 +254,6 @@ lower_loop_continue_block(nir_builder *b, nir_loop *loop, bool *repair_ssa) break; } - nir_lower_phis_to_regs_block(header, false); - if (num_continue == 0) { /* this loop doesn't continue at all. delete the continue construct */ nir_cf_list extracted; @@ -118,8 +321,17 @@ visit_cf_list(nir_builder *b, struct exec_list *list, bool *repair_ssa) } case nir_cf_node_loop: { nir_loop *loop = nir_cf_node_as_loop(node); + /* By first lowering inner loops, we ensure that we don't encounter + * any continue statements which don't belong to the current loop. + */ progress |= visit_cf_list(b, &loop->body, repair_ssa); + + /* If we lower continue constructs after inlining functions, they + * might contain nested loops. + */ progress |= visit_cf_list(b, &loop->continue_list, repair_ssa); + + /* Lower continue construct. */ progress |= lower_loop_continue_block(b, loop, repair_ssa); break; }