diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c index 6fa1dfb5133..3199af67a4f 100644 --- a/src/compiler/nir/nir.c +++ b/src/compiler/nir/nir.c @@ -647,6 +647,8 @@ nir_loop_create(nir_shader *shader) body->successors[0] = body; _mesa_set_add(body->predecessors, body); + exec_list_make_empty(&loop->continue_list); + return loop; } @@ -1924,11 +1926,17 @@ nir_block_cf_tree_next(nir_block *block) return nir_if_first_else_block(if_stmt); assert(block == nir_if_last_else_block(if_stmt)); - } - FALLTHROUGH; - - case nir_cf_node_loop: return nir_cf_node_as_block(nir_cf_node_next(parent)); + } + + case nir_cf_node_loop: { + nir_loop *loop = nir_cf_node_as_loop(parent); + if (block == nir_loop_last_block(loop) && + nir_loop_has_continue_construct(loop)) + return nir_loop_first_continue_block(loop); + + return nir_cf_node_as_block(nir_cf_node_next(parent)); + } case nir_cf_node_function: return NULL; @@ -1962,12 +1970,17 @@ nir_block_cf_tree_prev(nir_block *block) return nir_if_last_then_block(if_stmt); assert(block == nir_if_first_then_block(if_stmt)); - } - FALLTHROUGH; - - case nir_cf_node_loop: return nir_cf_node_as_block(nir_cf_node_prev(parent)); + } + case nir_cf_node_loop: { + nir_loop *loop = nir_cf_node_as_loop(parent); + if (nir_loop_has_continue_construct(loop) && + block == nir_loop_first_continue_block(loop)) + return nir_loop_last_block(loop); + assert(block == nir_loop_first_block(loop)); + return nir_cf_node_as_block(nir_cf_node_prev(parent)); + } case nir_cf_node_function: return NULL; @@ -2018,7 +2031,10 @@ nir_block *nir_cf_node_cf_tree_last(nir_cf_node *node) case nir_cf_node_loop: { nir_loop *loop = nir_cf_node_as_loop(node); - return nir_loop_last_block(loop); + if (nir_loop_has_continue_construct(loop)) + return nir_loop_last_continue_block(loop); + else + return nir_loop_last_block(loop); } case nir_cf_node_block: { diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 344d604352c..18ede24f8c2 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -2974,6 +2974,7 @@ typedef struct { nir_cf_node cf_node; struct exec_list body; /** < list of nir_cf_node */ + struct exec_list continue_list; /** < (optional) list of nir_cf_node */ nir_loop_info *info; nir_loop_control control; @@ -3209,6 +3210,40 @@ nir_loop_last_block(nir_loop *loop) return nir_cf_node_as_block(exec_node_data(nir_cf_node, tail, node)); } +static inline bool +nir_loop_has_continue_construct(const nir_loop *loop) +{ + return !exec_list_is_empty(&loop->continue_list); +} + +static inline nir_block * +nir_loop_first_continue_block(nir_loop *loop) +{ + assert(nir_loop_has_continue_construct(loop)); + struct exec_node *head = exec_list_get_head(&loop->continue_list); + return nir_cf_node_as_block(exec_node_data(nir_cf_node, head, node)); +} + +static inline nir_block * +nir_loop_last_continue_block(nir_loop *loop) +{ + assert(nir_loop_has_continue_construct(loop)); + struct exec_node *tail = exec_list_get_tail(&loop->continue_list); + return nir_cf_node_as_block(exec_node_data(nir_cf_node, tail, node)); +} + +/** + * Return the target block of a nir_jump_continue statement + */ +static inline nir_block * +nir_loop_continue_target(nir_loop *loop) +{ + if (nir_loop_has_continue_construct(loop)) + return nir_loop_first_continue_block(loop); + else + return nir_loop_first_block(loop); +} + /** * Return true if this list of cf_nodes contains a single empty block. */ diff --git a/src/compiler/nir/nir_clone.c b/src/compiler/nir/nir_clone.c index 700f1b4ddb7..52682a1a185 100644 --- a/src/compiler/nir/nir_clone.c +++ b/src/compiler/nir/nir_clone.c @@ -598,6 +598,10 @@ clone_loop(clone_state *state, struct exec_list *cf_list, const nir_loop *loop) nir_cf_node_insert_end(cf_list, &nloop->cf_node); clone_cf_list(state, &nloop->body, &loop->body); + if (nir_loop_has_continue_construct(loop)) { + nir_loop_add_continue_construct(nloop); + clone_cf_list(state, &nloop->continue_list, &loop->continue_list); + } return nloop; } diff --git a/src/compiler/nir/nir_control_flow.c b/src/compiler/nir/nir_control_flow.c index 4973c60de5c..43d02f2e176 100644 --- a/src/compiler/nir/nir_control_flow.c +++ b/src/compiler/nir/nir_control_flow.c @@ -288,10 +288,16 @@ block_add_normal_succs(nir_block *block) } else if (parent->type == nir_cf_node_loop) { nir_loop *loop = nir_cf_node_as_loop(parent); - nir_block *head_block = nir_loop_first_block(loop); + nir_block *cont_block; + if (block == nir_loop_last_block(loop)) { + cont_block = nir_loop_continue_target(loop); + } else { + assert(block == nir_loop_last_continue_block(loop)); + cont_block = nir_loop_first_block(loop); + } - link_blocks(block, head_block, NULL); - nir_insert_phi_undef(head_block, block); + link_blocks(block, cont_block, NULL); + nir_insert_phi_undef(cont_block, block); } else { nir_function_impl *impl = nir_cf_node_as_function(parent); link_blocks(block, impl->end_block, NULL); @@ -482,8 +488,8 @@ nir_handle_add_jump(nir_block *block) case nir_jump_continue: { nir_loop *loop = nearest_loop(&block->cf_node); - nir_block *first_block = nir_loop_first_block(loop); - link_blocks(block, first_block, NULL); + nir_block *cont_block = nir_loop_continue_target(loop); + link_blocks(block, cont_block, NULL); break; } @@ -665,6 +671,8 @@ cleanup_cf_node(nir_cf_node *node, nir_function_impl *impl) nir_loop *loop = nir_cf_node_as_loop(node); foreach_list_typed(nir_cf_node, child, node, &loop->body) cleanup_cf_node(child, impl); + foreach_list_typed(nir_cf_node, child, node, &loop->continue_list) + cleanup_cf_node(child, impl); break; } case nir_cf_node_function: { @@ -780,6 +788,8 @@ relink_jump_halt_cf_node(nir_cf_node *node, nir_block *end_block) nir_loop *loop = nir_cf_node_as_loop(node); foreach_list_typed(nir_cf_node, child, node, &loop->body) relink_jump_halt_cf_node(child, end_block); + foreach_list_typed(nir_cf_node, child, node, &loop->continue_list) + relink_jump_halt_cf_node(child, end_block); break; } diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c index 25868c8f779..f63401be27d 100644 --- a/src/compiler/nir/nir_print.c +++ b/src/compiler/nir/nir_print.c @@ -1658,6 +1658,15 @@ print_loop(nir_loop *loop, print_state *state, unsigned tabs) print_cf_node(node, state, tabs + 1); } print_tabs(tabs, fp); + + if (nir_loop_has_continue_construct(loop)) { + fprintf(fp, "} continue {\n"); + foreach_list_typed(nir_cf_node, node, node, &loop->continue_list) { + print_cf_node(node, state, tabs + 1); + } + print_tabs(tabs, fp); + } + fprintf(fp, "}\n"); } diff --git a/src/compiler/nir/nir_serialize.c b/src/compiler/nir/nir_serialize.c index 5ee571a6898..ae490e07348 100644 --- a/src/compiler/nir/nir_serialize.c +++ b/src/compiler/nir/nir_serialize.c @@ -1892,7 +1892,13 @@ write_loop(write_ctx *ctx, nir_loop *loop) { blob_write_uint8(ctx->blob, loop->control); blob_write_uint8(ctx->blob, loop->divergent); + bool has_continue_construct = nir_loop_has_continue_construct(loop); + blob_write_uint8(ctx->blob, has_continue_construct); + write_cf_list(ctx, &loop->body); + if (has_continue_construct) { + write_cf_list(ctx, &loop->continue_list); + } } static void @@ -1904,7 +1910,13 @@ read_loop(read_ctx *ctx, struct exec_list *cf_list) loop->control = blob_read_uint8(ctx->blob); loop->divergent = blob_read_uint8(ctx->blob); + bool has_continue_construct = blob_read_uint8(ctx->blob); + read_cf_list(ctx, &loop->body); + if (has_continue_construct) { + nir_loop_add_continue_construct(loop); + read_cf_list(ctx, &loop->continue_list); + } } static void diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c index 5e778109a73..82241c9f313 100644 --- a/src/compiler/nir/nir_validate.c +++ b/src/compiler/nir/nir_validate.c @@ -77,6 +77,9 @@ typedef struct { /* the current loop being visited */ nir_loop *loop; + /* weather the loop continue construct is being visited */ + bool in_loop_continue_construct; + /* the parent of the current cf node being visited */ nir_cf_node *parent_node; @@ -1073,6 +1076,7 @@ validate_jump_instr(nir_jump_instr *instr, validate_state *state) validate_assert(state, block->successors[1] == NULL); validate_assert(state, instr->target == NULL); validate_assert(state, instr->else_target == NULL); + validate_assert(state, !state->in_loop_continue_construct); break; case nir_jump_break: @@ -1092,12 +1096,13 @@ validate_jump_instr(nir_jump_instr *instr, validate_state *state) validate_assert(state, state->impl->structured); validate_assert(state, state->loop != NULL); if (state->loop) { - nir_block *first = nir_loop_first_block(state->loop); - validate_assert(state, block->successors[0] == first); + nir_block *cont_block = nir_loop_continue_target(state->loop); + validate_assert(state, block->successors[0] == cont_block); } validate_assert(state, block->successors[1] == NULL); validate_assert(state, instr->target == NULL); validate_assert(state, instr->else_target == NULL); + validate_assert(state, !state->in_loop_continue_construct); break; case nir_jump_goto: @@ -1242,6 +1247,7 @@ collect_blocks(struct exec_list *cf_list, validate_state *state) case nir_cf_node_loop: collect_blocks(&nir_cf_node_as_loop(node)->body, state); + collect_blocks(&nir_cf_node_as_loop(node)->continue_list, state); break; default: @@ -1310,8 +1316,15 @@ validate_block(nir_block *block, validate_state *state) if (next == NULL) { switch (state->parent_node->type) { case nir_cf_node_loop: { - nir_block *first = nir_loop_first_block(state->loop); - validate_assert(state, block->successors[0] == first); + if (block == nir_loop_last_block(state->loop)) { + nir_block *cont = nir_loop_continue_target(state->loop); + validate_assert(state, block->successors[0] == cont); + } else { + validate_assert(state, nir_loop_has_continue_construct(state->loop) && + block == nir_loop_last_continue_block(state->loop)); + nir_block *head = nir_loop_first_block(state->loop); + validate_assert(state, block->successors[0] == head); + } /* due to the hack for infinite loops, block->successors[1] may * point to the block after the loop. */ @@ -1421,14 +1434,21 @@ validate_loop(nir_loop *loop, validate_state *state) nir_cf_node *old_parent = state->parent_node; state->parent_node = &loop->cf_node; nir_loop *old_loop = state->loop; + bool old_continue_construct = state->in_loop_continue_construct; state->loop = loop; + state->in_loop_continue_construct = false; foreach_list_typed(nir_cf_node, cf_node, node, &loop->body) { validate_cf_node(cf_node, state); } - + state->in_loop_continue_construct = true; + foreach_list_typed(nir_cf_node, cf_node, node, &loop->continue_list) { + validate_cf_node(cf_node, state); + } + state->in_loop_continue_construct = false; state->parent_node = old_parent; state->loop = old_loop; + state->in_loop_continue_construct = old_continue_construct; } static void @@ -1742,6 +1762,7 @@ init_validate_state(validate_state *state) state->errors = _mesa_pointer_hash_table_create(state->mem_ctx); state->loop = NULL; + state->in_loop_continue_construct = false; state->instr = NULL; state->var = NULL; }