vtn: Add spirv_library_to_nir_builder feature

This new entrypoint takes in a SPIR-V blob and generates a header containing
a static inline nir_builder-family function for each function in the SPIR-V
library. The generated function will look for the function in the shader and, if
not found, insert a new nir_function with the appropriate signature -- to be
linked with the library later. Then, it will call the function, with the
appropriate gymnastics to handle return values as necessary.

This makes it super convenient to wrap CL libraries for use in a NIR pass.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25498>
This commit is contained in:
Alyssa Rosenzweig 2023-10-01 19:45:39 -04:00 committed by Marge Bot
parent a2d3c74094
commit f164edfe71
2 changed files with 171 additions and 0 deletions

View file

@ -146,6 +146,10 @@ nir_shader *spirv_to_nir(const uint32_t *words, size_t word_count,
const struct spirv_to_nir_options *options,
const nir_shader_compiler_options *nir_options);
bool
spirv_library_to_nir_builder(FILE *fp, const uint32_t *words, size_t word_count,
const struct spirv_to_nir_options *options);
#ifdef __cplusplus
}
#endif

View file

@ -7142,3 +7142,170 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
return shader;
}
static bool
func_to_nir_builder(FILE *fp, struct vtn_function *func)
{
nir_function *nir_func = func->nir_func;
struct vtn_type *return_type = func->type->return_type;
bool returns = return_type->base_type != vtn_base_type_void;
if (returns && return_type->base_type != vtn_base_type_scalar &&
return_type->base_type != vtn_base_type_vector) {
fprintf(stderr, "Unsupported return type for %s", nir_func->name);
return false;
}
/* If there is a return type, the first NIR parameter is the return deref,
* so offset by that for logical parameter iteration.
*/
unsigned first_param = returns ? 1 : 0;
/* Generate function signature */
fprintf(fp, "static inline %s\n", returns ? "nir_def *": "void");
fprintf(fp, "%s(nir_builder *b", nir_func->name);
/* TODO: Can we recover parameter names? */
for (unsigned i = first_param; i < nir_func->num_params; ++i) {
fprintf(fp, ", nir_def *arg%u", i);
}
fprintf(fp, ")\n{\n");
/* Validate inputs. nir_validate will do this too, but the
* errors/backtraces from these asserts should be nicer.
*/
for (unsigned i = first_param; i < nir_func->num_params; ++i) {
nir_parameter *param = &nir_func->params[i];
fprintf(fp, " assert(arg%u->bit_size == %u);\n", i, param->bit_size);
fprintf(fp, " assert(arg%u->num_components == %u);\n", i,
param->num_components);
fprintf(fp, "\n");
}
/* Find the function to call. If not found, create a prototype */
fprintf(fp, " nir_function *func = nir_shader_get_function_for_name(b->shader, \"%s\");\n",
nir_func->name);
fprintf(fp, "\n");
fprintf(fp, " if (!func) {\n");
fprintf(fp, " func = nir_function_create(b->shader, \"%s\");\n",
nir_func->name);
fprintf(fp, " func->num_params = %u;\n", nir_func->num_params);
fprintf(fp, " func->params = ralloc_array(b->shader, nir_parameter, func->num_params);\n");
for (unsigned i = 0; i < nir_func->num_params; ++i) {
fprintf(fp, "\n");
fprintf(fp, " func->params[%u].bit_size = %u;\n", i,
nir_func->params[i].bit_size);
fprintf(fp, " func->params[%u].num_components = %u;\n", i,
nir_func->params[i].num_components);
}
fprintf(fp, " }\n\n");
if (returns) {
/* We assume that vec3 variables are lowered to vec4. Mirror that here so
* we don't need to lower vec3 to vec4 again at link-time.
*/
assert(glsl_type_is_vector_or_scalar(return_type->type));
unsigned elements = return_type->type->vector_elements;
if (elements == 3)
elements = 4;
/* Reconstruct the return type. */
fprintf(fp, " const struct glsl_type *ret_type = glsl_vector_type(%u, %u);\n",
return_type->type->base_type, elements);
/* With the type, we can make a variable and get a deref to pass in */
fprintf(fp, " nir_variable *ret = nir_local_variable_create(b->impl, ret_type, \"return\");\n");
fprintf(fp, " nir_deref_instr *deref = nir_build_deref_var(b, ret);\n");
/* XXX: This is a hack due to ptr size differing between KERNEL and other
* shader stages. This needs to be fixed in core NIR.
*/
fprintf(fp, " deref->def.bit_size = %u;\n", nir_func->params[0].bit_size);
fprintf(fp, "\n");
}
/* Call the function */
fprintf(fp, " nir_call(b, func");
if (returns)
fprintf(fp, ", &deref->def");
for (unsigned i = first_param; i < nir_func->num_params; ++i)
fprintf(fp, ", arg%u", i);
fprintf(fp, ");\n");
/* Load the return value if any, undoing the vec3->vec4 lowering. */
if (returns) {
fprintf(fp, "\n");
if (return_type->type->vector_elements == 3)
fprintf(fp, " return nir_trim_vector(b, nir_load_deref(b, deref), 3);\n");
else
fprintf(fp, " return nir_load_deref(b, deref);\n");
}
fprintf(fp, "}\n\n");
return true;
}
bool
spirv_library_to_nir_builder(FILE *fp, const uint32_t *words, size_t word_count,
const struct spirv_to_nir_options *options)
{
#ifndef NDEBUG
static once_flag initialized_debug_flag = ONCE_FLAG_INIT;
call_once(&initialized_debug_flag, initialize_mesa_spirv_debug);
#endif
const uint32_t *word_end = words + word_count;
struct vtn_builder *b = vtn_create_builder(words, word_count,
MESA_SHADER_KERNEL, "placeholder name",
options);
if (b == NULL)
return false;
/* See also _vtn_fail() */
if (vtn_setjmp(b->fail_jump)) {
ralloc_free(b);
return false;
}
b->shader = nir_shader_create(b, MESA_SHADER_KERNEL,
&(const nir_shader_compiler_options){0}, NULL);
/* Skip the SPIR-V header, handled at vtn_create_builder */
words+= 5;
/* Handle all the preamble instructions */
words = vtn_foreach_instruction(b, words, word_end,
vtn_handle_preamble_instruction);
/* Handle all variable, type, and constant instructions */
words = vtn_foreach_instruction(b, words, word_end,
vtn_handle_variable_or_type_instruction);
/* Set types on all vtn_values */
vtn_foreach_instruction(b, words, word_end, vtn_set_instruction_result_type);
vtn_build_cfg(b, words, word_end);
fprintf(fp, "#include \"compiler/nir/nir_builder.h\"\n\n");
vtn_foreach_function(func, &b->functions) {
if (func->linkage != SpvLinkageTypeExport)
continue;
if (!func_to_nir_builder(fp, func))
return false;
}
ralloc_free(b);
return true;
}