diff --git a/src/compiler/spirv/nir_spirv.h b/src/compiler/spirv/nir_spirv.h index 1369a0e6b47..8984e59b815 100644 --- a/src/compiler/spirv/nir_spirv.h +++ b/src/compiler/spirv/nir_spirv.h @@ -146,6 +146,10 @@ nir_shader *spirv_to_nir(const uint32_t *words, size_t word_count, const struct spirv_to_nir_options *options, const nir_shader_compiler_options *nir_options); +bool +spirv_library_to_nir_builder(FILE *fp, const uint32_t *words, size_t word_count, + const struct spirv_to_nir_options *options); + #ifdef __cplusplus } #endif diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 4729916b807..ccb5497edc6 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -7142,3 +7142,170 @@ spirv_to_nir(const uint32_t *words, size_t word_count, return shader; } + +static bool +func_to_nir_builder(FILE *fp, struct vtn_function *func) +{ + nir_function *nir_func = func->nir_func; + struct vtn_type *return_type = func->type->return_type; + bool returns = return_type->base_type != vtn_base_type_void; + + if (returns && return_type->base_type != vtn_base_type_scalar && + return_type->base_type != vtn_base_type_vector) { + fprintf(stderr, "Unsupported return type for %s", nir_func->name); + return false; + } + + /* If there is a return type, the first NIR parameter is the return deref, + * so offset by that for logical parameter iteration. + */ + unsigned first_param = returns ? 1 : 0; + + /* Generate function signature */ + fprintf(fp, "static inline %s\n", returns ? "nir_def *": "void"); + fprintf(fp, "%s(nir_builder *b", nir_func->name); + + /* TODO: Can we recover parameter names? */ + for (unsigned i = first_param; i < nir_func->num_params; ++i) { + fprintf(fp, ", nir_def *arg%u", i); + } + + fprintf(fp, ")\n{\n"); + + /* Validate inputs. nir_validate will do this too, but the + * errors/backtraces from these asserts should be nicer. + */ + for (unsigned i = first_param; i < nir_func->num_params; ++i) { + nir_parameter *param = &nir_func->params[i]; + fprintf(fp, " assert(arg%u->bit_size == %u);\n", i, param->bit_size); + fprintf(fp, " assert(arg%u->num_components == %u);\n", i, + param->num_components); + fprintf(fp, "\n"); + } + + /* Find the function to call. If not found, create a prototype */ + fprintf(fp, " nir_function *func = nir_shader_get_function_for_name(b->shader, \"%s\");\n", + nir_func->name); + fprintf(fp, "\n"); + fprintf(fp, " if (!func) {\n"); + fprintf(fp, " func = nir_function_create(b->shader, \"%s\");\n", + nir_func->name); + fprintf(fp, " func->num_params = %u;\n", nir_func->num_params); + fprintf(fp, " func->params = ralloc_array(b->shader, nir_parameter, func->num_params);\n"); + + for (unsigned i = 0; i < nir_func->num_params; ++i) { + fprintf(fp, "\n"); + fprintf(fp, " func->params[%u].bit_size = %u;\n", i, + nir_func->params[i].bit_size); + fprintf(fp, " func->params[%u].num_components = %u;\n", i, + nir_func->params[i].num_components); + } + + fprintf(fp, " }\n\n"); + + + if (returns) { + /* We assume that vec3 variables are lowered to vec4. Mirror that here so + * we don't need to lower vec3 to vec4 again at link-time. + */ + assert(glsl_type_is_vector_or_scalar(return_type->type)); + unsigned elements = return_type->type->vector_elements; + if (elements == 3) + elements = 4; + + /* Reconstruct the return type. */ + fprintf(fp, " const struct glsl_type *ret_type = glsl_vector_type(%u, %u);\n", + return_type->type->base_type, elements); + + /* With the type, we can make a variable and get a deref to pass in */ + fprintf(fp, " nir_variable *ret = nir_local_variable_create(b->impl, ret_type, \"return\");\n"); + fprintf(fp, " nir_deref_instr *deref = nir_build_deref_var(b, ret);\n"); + + /* XXX: This is a hack due to ptr size differing between KERNEL and other + * shader stages. This needs to be fixed in core NIR. + */ + fprintf(fp, " deref->def.bit_size = %u;\n", nir_func->params[0].bit_size); + fprintf(fp, "\n"); + } + + /* Call the function */ + fprintf(fp, " nir_call(b, func"); + + if (returns) + fprintf(fp, ", &deref->def"); + + for (unsigned i = first_param; i < nir_func->num_params; ++i) + fprintf(fp, ", arg%u", i); + + fprintf(fp, ");\n"); + + /* Load the return value if any, undoing the vec3->vec4 lowering. */ + if (returns) { + fprintf(fp, "\n"); + + if (return_type->type->vector_elements == 3) + fprintf(fp, " return nir_trim_vector(b, nir_load_deref(b, deref), 3);\n"); + else + fprintf(fp, " return nir_load_deref(b, deref);\n"); + } + + fprintf(fp, "}\n\n"); + return true; +} + +bool +spirv_library_to_nir_builder(FILE *fp, const uint32_t *words, size_t word_count, + const struct spirv_to_nir_options *options) +{ +#ifndef NDEBUG + static once_flag initialized_debug_flag = ONCE_FLAG_INIT; + call_once(&initialized_debug_flag, initialize_mesa_spirv_debug); +#endif + + const uint32_t *word_end = words + word_count; + + struct vtn_builder *b = vtn_create_builder(words, word_count, + MESA_SHADER_KERNEL, "placeholder name", + options); + + if (b == NULL) + return false; + + /* See also _vtn_fail() */ + if (vtn_setjmp(b->fail_jump)) { + ralloc_free(b); + return false; + } + + b->shader = nir_shader_create(b, MESA_SHADER_KERNEL, + &(const nir_shader_compiler_options){0}, NULL); + + /* Skip the SPIR-V header, handled at vtn_create_builder */ + words+= 5; + + /* Handle all the preamble instructions */ + words = vtn_foreach_instruction(b, words, word_end, + vtn_handle_preamble_instruction); + + /* Handle all variable, type, and constant instructions */ + words = vtn_foreach_instruction(b, words, word_end, + vtn_handle_variable_or_type_instruction); + + /* Set types on all vtn_values */ + vtn_foreach_instruction(b, words, word_end, vtn_set_instruction_result_type); + + vtn_build_cfg(b, words, word_end); + + fprintf(fp, "#include \"compiler/nir/nir_builder.h\"\n\n"); + + vtn_foreach_function(func, &b->functions) { + if (func->linkage != SpvLinkageTypeExport) + continue; + + if (!func_to_nir_builder(fp, func)) + return false; + } + + ralloc_free(b); + return true; +}