nir/spirv: Handle OpBranchConditional

We do control-flow handling as a two-step process.  The first step is to
walk the instructions list and record various information about blocks and
functions.  This is where the acutal nir_function_overload objects get
created.  We also record the start/stop instruction for each block.  Then
a second pass walks over each of the functions and over the blocks in each
function in a way that's NIR-friendly and actually parses the instructions.
This commit is contained in:
Jason Ekstrand 2015-05-04 10:23:09 -07:00
parent d216dcee94
commit b7904b8281

View file

@ -44,6 +44,19 @@ enum vtn_value_type {
vtn_value_type_ssa,
};
struct vtn_block {
const uint32_t *label;
const uint32_t *branch;
nir_block *block;
};
struct vtn_function {
struct exec_node node;
nir_function_overload *overload;
struct vtn_block *start_block;
};
struct vtn_value {
enum vtn_value_type value_type;
const char *name;
@ -54,8 +67,8 @@ struct vtn_value {
const struct glsl_type *type;
nir_constant *constant;
nir_deref_var *deref;
nir_function_impl *impl;
nir_block *block;
struct vtn_function *func;
struct vtn_block *block;
nir_ssa_def *ssa;
};
};
@ -71,12 +84,17 @@ struct vtn_builder {
nir_shader *shader;
nir_function_impl *impl;
struct exec_list *cf_list;
struct vtn_block *block;
struct vtn_block *merge_block;
unsigned value_id_bound;
struct vtn_value *values;
SpvExecutionModel execution_model;
struct vtn_value *entry_point;
struct vtn_function *func;
struct exec_list functions;
};
static struct vtn_value *
@ -672,60 +690,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
}
static void
vtn_handle_functions(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
switch (opcode) {
case SpvOpFunction: {
assert(b->impl == NULL);
const struct glsl_type *result_type =
vtn_value(b, w[1], vtn_value_type_type)->type;
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
const struct glsl_type *func_type =
vtn_value(b, w[4], vtn_value_type_type)->type;
assert(glsl_get_function_return_type(func_type) == result_type);
nir_function *func =
nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
nir_function_overload *overload = nir_function_overload_create(func);
overload->num_params = glsl_get_length(func_type);
overload->params = ralloc_array(overload, nir_parameter,
overload->num_params);
for (unsigned i = 0; i < overload->num_params; i++) {
const struct glsl_function_param *param =
glsl_get_function_param(func_type, i);
overload->params[i].type = param->type;
if (param->in) {
if (param->out) {
overload->params[i].param_type = nir_parameter_inout;
} else {
overload->params[i].param_type = nir_parameter_in;
}
} else {
if (param->out) {
overload->params[i].param_type = nir_parameter_out;
} else {
assert(!"Parameter is neither in nor out");
}
}
}
val->impl = b->impl = nir_function_impl_create(overload);
b->cf_list = &b->impl->body;
break;
}
case SpvOpFunctionEnd:
b->impl = NULL;
break;
case SpvOpFunctionParameter:
case SpvOpFunctionCall:
default:
unreachable("Unhandled opcode");
}
unreachable("Unhandled opcode");
}
static void
@ -841,22 +809,118 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
return true;
}
static bool
vtn_handle_first_cfg_pass_instruction(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
switch (opcode) {
case SpvOpFunction: {
assert(b->func == NULL);
b->func = rzalloc(b, struct vtn_function);
const struct glsl_type *result_type =
vtn_value(b, w[1], vtn_value_type_type)->type;
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
const struct glsl_type *func_type =
vtn_value(b, w[4], vtn_value_type_type)->type;
assert(glsl_get_function_return_type(func_type) == result_type);
nir_function *func =
nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
nir_function_overload *overload = nir_function_overload_create(func);
overload->num_params = glsl_get_length(func_type);
overload->params = ralloc_array(overload, nir_parameter,
overload->num_params);
for (unsigned i = 0; i < overload->num_params; i++) {
const struct glsl_function_param *param =
glsl_get_function_param(func_type, i);
overload->params[i].type = param->type;
if (param->in) {
if (param->out) {
overload->params[i].param_type = nir_parameter_inout;
} else {
overload->params[i].param_type = nir_parameter_in;
}
} else {
if (param->out) {
overload->params[i].param_type = nir_parameter_out;
} else {
assert(!"Parameter is neither in nor out");
}
}
}
b->func->overload = overload;
break;
}
case SpvOpFunctionEnd:
b->func = NULL;
break;
case SpvOpFunctionParameter:
break; /* Does nothing */
case SpvOpLabel: {
assert(b->block == NULL);
b->block = rzalloc(b, struct vtn_block);
b->block->label = w;
vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;
if (b->func->start_block == NULL) {
/* This is the first block encountered for this function. In this
* case, we set the start block and add it to the list of
* implemented functions that we'll walk later.
*/
b->func->start_block = b->block;
exec_list_push_tail(&b->functions, &b->func->node);
}
break;
}
case SpvOpBranch:
case SpvOpBranchConditional:
case SpvOpSwitch:
case SpvOpKill:
case SpvOpReturn:
case SpvOpReturnValue:
case SpvOpUnreachable:
assert(b->block);
b->block->branch = w;
b->block = NULL;
break;
default:
/* Continue on as per normal */
return true;
}
return true;
}
static bool
vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
switch (opcode) {
case SpvOpLabel: {
struct vtn_block *block = vtn_value(b, w[1], vtn_value_type_block)->block;
struct exec_node *list_tail = exec_list_get_tail(b->cf_list);
nir_cf_node *tail_node = exec_node_data(nir_cf_node, list_tail, node);
assert(tail_node->type == nir_cf_node_block);
nir_block *block = nir_cf_node_as_block(tail_node);
assert(exec_list_is_empty(&block->instr_list));
vtn_push_value(b, w[1], vtn_value_type_block)->block = block;
block->block = nir_cf_node_as_block(tail_node);
assert(exec_list_is_empty(&block->block->instr_list));
break;
}
case SpvOpLoopMerge:
case SpvOpSelectionMerge:
assert(b->merge_block == NULL);
/* TODO: Selection Control */
b->merge_block = vtn_value(b, w[1], vtn_value_type_block)->block;
break;
case SpvOpUndef:
vtn_push_value(b, w[2], vtn_value_type_undef);
break;
@ -878,11 +942,8 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
vtn_handle_variables(b, opcode, w, count);
break;
case SpvOpFunction:
case SpvOpFunctionEnd:
case SpvOpFunctionParameter:
case SpvOpFunctionCall:
vtn_handle_functions(b, opcode, w, count);
vtn_handle_function_call(b, opcode, w, count);
break;
case SpvOpTextureSample:
@ -1011,11 +1072,66 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
return true;
}
static void
vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start,
struct vtn_block *end)
{
struct vtn_block *block = start;
while (block != end) {
vtn_foreach_instruction(b, block->label, block->branch,
vtn_handle_body_instruction);
const uint32_t *w = block->branch;
SpvOp branch_op = w[0] & SpvOpCodeMask;
switch (branch_op) {
case SpvOpBranch: {
assert(vtn_value(b, w[1], vtn_value_type_block)->block == end);
return;
}
case SpvOpBranchConditional: {
/* Gather up the branch blocks */
struct vtn_block *then_block =
vtn_value(b, w[2], vtn_value_type_block)->block;
struct vtn_block *else_block =
vtn_value(b, w[3], vtn_value_type_block)->block;
struct vtn_block *merge_block = b->merge_block;
nir_if *if_stmt = nir_if_create(b->shader);
if_stmt->condition = nir_src_for_ssa(vtn_ssa_value(b, w[1]));
nir_cf_node_insert_end(b->cf_list, &if_stmt->cf_node);
struct exec_list *old_list = b->cf_list;
b->cf_list = &if_stmt->then_list;
vtn_walk_blocks(b, then_block, merge_block);
b->cf_list = &if_stmt->else_list;
vtn_walk_blocks(b, else_block, merge_block);
b->cf_list = old_list;
block = merge_block;
continue;
}
case SpvOpSwitch:
case SpvOpKill:
case SpvOpReturn:
case SpvOpReturnValue:
case SpvOpUnreachable:
default:
unreachable("Unhandled opcode");
}
}
}
nir_shader *
spirv_to_nir(const uint32_t *words, size_t word_count,
gl_shader_stage stage,
const nir_shader_compiler_options *options)
{
const uint32_t *word_end = words + word_count;
/* Handle the SPIR-V header (first 4 dwords) */
assert(word_count > 5);
@ -1034,16 +1150,21 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
b->shader = shader;
b->value_id_bound = value_id_bound;
b->values = ralloc_array(b, struct vtn_value, value_id_bound);
const uint32_t *word_end = words + word_count;
exec_list_make_empty(&b->functions);
/* Handle all the preamble instructions */
words = vtn_foreach_instruction(b, words, word_end,
vtn_handle_preamble_instruction);
words = vtn_foreach_instruction(b, words, word_end,
vtn_handle_body_instruction);
assert(words == word_end);
/* Do a very quick CFG analysis pass */
vtn_foreach_instruction(b, words, word_end,
vtn_handle_first_cfg_pass_instruction);
foreach_list_typed(struct vtn_function, func, node, &b->functions) {
b->impl = nir_function_impl_create(func->overload);
b->cf_list = &b->impl->body;
vtn_walk_blocks(b, func->start_block, NULL);
}
ralloc_free(b);