nir: add Continue Construct to nir_loop

The added continue_list corresponds to the SPIR-V
Continue Construct and serves as a converged control-flow
construct and is executed after each continue statement
and before the next iteration of the loop body.

Also adds validation rules for loops with Continue Construct

Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13962>
This commit is contained in:
Daniel Schürmann 2021-12-01 17:34:48 +01:00 committed by Marge Bot
parent e0c6ad1ce5
commit d4b97bf3fa
7 changed files with 126 additions and 19 deletions

View file

@ -647,6 +647,8 @@ nir_loop_create(nir_shader *shader)
body->successors[0] = body;
_mesa_set_add(body->predecessors, body);
exec_list_make_empty(&loop->continue_list);
return loop;
}
@ -1924,11 +1926,17 @@ nir_block_cf_tree_next(nir_block *block)
return nir_if_first_else_block(if_stmt);
assert(block == nir_if_last_else_block(if_stmt));
}
FALLTHROUGH;
case nir_cf_node_loop:
return nir_cf_node_as_block(nir_cf_node_next(parent));
}
case nir_cf_node_loop: {
nir_loop *loop = nir_cf_node_as_loop(parent);
if (block == nir_loop_last_block(loop) &&
nir_loop_has_continue_construct(loop))
return nir_loop_first_continue_block(loop);
return nir_cf_node_as_block(nir_cf_node_next(parent));
}
case nir_cf_node_function:
return NULL;
@ -1962,12 +1970,17 @@ nir_block_cf_tree_prev(nir_block *block)
return nir_if_last_then_block(if_stmt);
assert(block == nir_if_first_then_block(if_stmt));
}
FALLTHROUGH;
case nir_cf_node_loop:
return nir_cf_node_as_block(nir_cf_node_prev(parent));
}
case nir_cf_node_loop: {
nir_loop *loop = nir_cf_node_as_loop(parent);
if (nir_loop_has_continue_construct(loop) &&
block == nir_loop_first_continue_block(loop))
return nir_loop_last_block(loop);
assert(block == nir_loop_first_block(loop));
return nir_cf_node_as_block(nir_cf_node_prev(parent));
}
case nir_cf_node_function:
return NULL;
@ -2018,7 +2031,10 @@ nir_block *nir_cf_node_cf_tree_last(nir_cf_node *node)
case nir_cf_node_loop: {
nir_loop *loop = nir_cf_node_as_loop(node);
return nir_loop_last_block(loop);
if (nir_loop_has_continue_construct(loop))
return nir_loop_last_continue_block(loop);
else
return nir_loop_last_block(loop);
}
case nir_cf_node_block: {

View file

@ -2974,6 +2974,7 @@ typedef struct {
nir_cf_node cf_node;
struct exec_list body; /** < list of nir_cf_node */
struct exec_list continue_list; /** < (optional) list of nir_cf_node */
nir_loop_info *info;
nir_loop_control control;
@ -3209,6 +3210,40 @@ nir_loop_last_block(nir_loop *loop)
return nir_cf_node_as_block(exec_node_data(nir_cf_node, tail, node));
}
static inline bool
nir_loop_has_continue_construct(const nir_loop *loop)
{
return !exec_list_is_empty(&loop->continue_list);
}
static inline nir_block *
nir_loop_first_continue_block(nir_loop *loop)
{
assert(nir_loop_has_continue_construct(loop));
struct exec_node *head = exec_list_get_head(&loop->continue_list);
return nir_cf_node_as_block(exec_node_data(nir_cf_node, head, node));
}
static inline nir_block *
nir_loop_last_continue_block(nir_loop *loop)
{
assert(nir_loop_has_continue_construct(loop));
struct exec_node *tail = exec_list_get_tail(&loop->continue_list);
return nir_cf_node_as_block(exec_node_data(nir_cf_node, tail, node));
}
/**
* Return the target block of a nir_jump_continue statement
*/
static inline nir_block *
nir_loop_continue_target(nir_loop *loop)
{
if (nir_loop_has_continue_construct(loop))
return nir_loop_first_continue_block(loop);
else
return nir_loop_first_block(loop);
}
/**
* Return true if this list of cf_nodes contains a single empty block.
*/

View file

@ -598,6 +598,10 @@ clone_loop(clone_state *state, struct exec_list *cf_list, const nir_loop *loop)
nir_cf_node_insert_end(cf_list, &nloop->cf_node);
clone_cf_list(state, &nloop->body, &loop->body);
if (nir_loop_has_continue_construct(loop)) {
nir_loop_add_continue_construct(nloop);
clone_cf_list(state, &nloop->continue_list, &loop->continue_list);
}
return nloop;
}

View file

@ -288,10 +288,16 @@ block_add_normal_succs(nir_block *block)
} else if (parent->type == nir_cf_node_loop) {
nir_loop *loop = nir_cf_node_as_loop(parent);
nir_block *head_block = nir_loop_first_block(loop);
nir_block *cont_block;
if (block == nir_loop_last_block(loop)) {
cont_block = nir_loop_continue_target(loop);
} else {
assert(block == nir_loop_last_continue_block(loop));
cont_block = nir_loop_first_block(loop);
}
link_blocks(block, head_block, NULL);
nir_insert_phi_undef(head_block, block);
link_blocks(block, cont_block, NULL);
nir_insert_phi_undef(cont_block, block);
} else {
nir_function_impl *impl = nir_cf_node_as_function(parent);
link_blocks(block, impl->end_block, NULL);
@ -482,8 +488,8 @@ nir_handle_add_jump(nir_block *block)
case nir_jump_continue: {
nir_loop *loop = nearest_loop(&block->cf_node);
nir_block *first_block = nir_loop_first_block(loop);
link_blocks(block, first_block, NULL);
nir_block *cont_block = nir_loop_continue_target(loop);
link_blocks(block, cont_block, NULL);
break;
}
@ -665,6 +671,8 @@ cleanup_cf_node(nir_cf_node *node, nir_function_impl *impl)
nir_loop *loop = nir_cf_node_as_loop(node);
foreach_list_typed(nir_cf_node, child, node, &loop->body)
cleanup_cf_node(child, impl);
foreach_list_typed(nir_cf_node, child, node, &loop->continue_list)
cleanup_cf_node(child, impl);
break;
}
case nir_cf_node_function: {
@ -780,6 +788,8 @@ relink_jump_halt_cf_node(nir_cf_node *node, nir_block *end_block)
nir_loop *loop = nir_cf_node_as_loop(node);
foreach_list_typed(nir_cf_node, child, node, &loop->body)
relink_jump_halt_cf_node(child, end_block);
foreach_list_typed(nir_cf_node, child, node, &loop->continue_list)
relink_jump_halt_cf_node(child, end_block);
break;
}

View file

@ -1658,6 +1658,15 @@ print_loop(nir_loop *loop, print_state *state, unsigned tabs)
print_cf_node(node, state, tabs + 1);
}
print_tabs(tabs, fp);
if (nir_loop_has_continue_construct(loop)) {
fprintf(fp, "} continue {\n");
foreach_list_typed(nir_cf_node, node, node, &loop->continue_list) {
print_cf_node(node, state, tabs + 1);
}
print_tabs(tabs, fp);
}
fprintf(fp, "}\n");
}

View file

@ -1892,7 +1892,13 @@ write_loop(write_ctx *ctx, nir_loop *loop)
{
blob_write_uint8(ctx->blob, loop->control);
blob_write_uint8(ctx->blob, loop->divergent);
bool has_continue_construct = nir_loop_has_continue_construct(loop);
blob_write_uint8(ctx->blob, has_continue_construct);
write_cf_list(ctx, &loop->body);
if (has_continue_construct) {
write_cf_list(ctx, &loop->continue_list);
}
}
static void
@ -1904,7 +1910,13 @@ read_loop(read_ctx *ctx, struct exec_list *cf_list)
loop->control = blob_read_uint8(ctx->blob);
loop->divergent = blob_read_uint8(ctx->blob);
bool has_continue_construct = blob_read_uint8(ctx->blob);
read_cf_list(ctx, &loop->body);
if (has_continue_construct) {
nir_loop_add_continue_construct(loop);
read_cf_list(ctx, &loop->continue_list);
}
}
static void

View file

@ -77,6 +77,9 @@ typedef struct {
/* the current loop being visited */
nir_loop *loop;
/* weather the loop continue construct is being visited */
bool in_loop_continue_construct;
/* the parent of the current cf node being visited */
nir_cf_node *parent_node;
@ -1073,6 +1076,7 @@ validate_jump_instr(nir_jump_instr *instr, validate_state *state)
validate_assert(state, block->successors[1] == NULL);
validate_assert(state, instr->target == NULL);
validate_assert(state, instr->else_target == NULL);
validate_assert(state, !state->in_loop_continue_construct);
break;
case nir_jump_break:
@ -1092,12 +1096,13 @@ validate_jump_instr(nir_jump_instr *instr, validate_state *state)
validate_assert(state, state->impl->structured);
validate_assert(state, state->loop != NULL);
if (state->loop) {
nir_block *first = nir_loop_first_block(state->loop);
validate_assert(state, block->successors[0] == first);
nir_block *cont_block = nir_loop_continue_target(state->loop);
validate_assert(state, block->successors[0] == cont_block);
}
validate_assert(state, block->successors[1] == NULL);
validate_assert(state, instr->target == NULL);
validate_assert(state, instr->else_target == NULL);
validate_assert(state, !state->in_loop_continue_construct);
break;
case nir_jump_goto:
@ -1242,6 +1247,7 @@ collect_blocks(struct exec_list *cf_list, validate_state *state)
case nir_cf_node_loop:
collect_blocks(&nir_cf_node_as_loop(node)->body, state);
collect_blocks(&nir_cf_node_as_loop(node)->continue_list, state);
break;
default:
@ -1310,8 +1316,15 @@ validate_block(nir_block *block, validate_state *state)
if (next == NULL) {
switch (state->parent_node->type) {
case nir_cf_node_loop: {
nir_block *first = nir_loop_first_block(state->loop);
validate_assert(state, block->successors[0] == first);
if (block == nir_loop_last_block(state->loop)) {
nir_block *cont = nir_loop_continue_target(state->loop);
validate_assert(state, block->successors[0] == cont);
} else {
validate_assert(state, nir_loop_has_continue_construct(state->loop) &&
block == nir_loop_last_continue_block(state->loop));
nir_block *head = nir_loop_first_block(state->loop);
validate_assert(state, block->successors[0] == head);
}
/* due to the hack for infinite loops, block->successors[1] may
* point to the block after the loop.
*/
@ -1421,14 +1434,21 @@ validate_loop(nir_loop *loop, validate_state *state)
nir_cf_node *old_parent = state->parent_node;
state->parent_node = &loop->cf_node;
nir_loop *old_loop = state->loop;
bool old_continue_construct = state->in_loop_continue_construct;
state->loop = loop;
state->in_loop_continue_construct = false;
foreach_list_typed(nir_cf_node, cf_node, node, &loop->body) {
validate_cf_node(cf_node, state);
}
state->in_loop_continue_construct = true;
foreach_list_typed(nir_cf_node, cf_node, node, &loop->continue_list) {
validate_cf_node(cf_node, state);
}
state->in_loop_continue_construct = false;
state->parent_node = old_parent;
state->loop = old_loop;
state->in_loop_continue_construct = old_continue_construct;
}
static void
@ -1742,6 +1762,7 @@ init_validate_state(validate_state *state)
state->errors = _mesa_pointer_hash_table_create(state->mem_ctx);
state->loop = NULL;
state->in_loop_continue_construct = false;
state->instr = NULL;
state->var = NULL;
}