nir: Add indirect calls

Used to jump to a function referred to by a runtime pointer.

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29577>
This commit is contained in:
Friedrich Vock 2024-01-07 22:14:51 +01:00 committed by Marge Bot
parent 4432cf0a58
commit bb40284f76
11 changed files with 74 additions and 4 deletions

View file

@ -500,6 +500,7 @@ nir_function_create(nir_shader *shader, const char *name)
func->is_preamble = false;
func->dont_inline = false;
func->should_inline = false;
func->driver_attributes = 0;
func->is_subroutine = false;
func->is_tmp_globals_wrapper = false;
func->subroutine_index = 0;

View file

@ -1871,6 +1871,10 @@ typedef struct {
nir_instr instr;
struct nir_function *callee;
/* If this function call is indirect, the function pointer to call.
* Otherwise, null initialized.
*/
nir_src indirect_callee;
unsigned num_params;
nir_src params[];
@ -3593,6 +3597,15 @@ typedef struct {
nir_variable_mode mode;
/* Drivers may optionally stash flags here describing the parameter.
* For example, this might encode whether the driver expects the value
* to be uniform or divergent, if the driver handles divergent parameters
* differently from uniform ones.
*
* NIR will preserve this value but does not interpret it in any way.
*/
uint32_t driver_attributes;
/* The type of the function param */
const struct glsl_type *type;
@ -3618,6 +3631,14 @@ typedef struct nir_function {
*/
nir_function_impl *impl;
/* Drivers may optionally stash flags here describing the function call.
* For example, this might encode the ABI used for the call if a driver
* supports multiple ABIs.
*
* NIR will preserve this value but does not interpret it in any way.
*/
uint32_t driver_attributes;
bool is_entrypoint;
/* from SPIR-V linkage, only for libraries */
bool is_exported;

View file

@ -2243,6 +2243,22 @@ nir_build_call(nir_builder *build, nir_function *func, size_t count,
nir_builder_instr_insert(build, &call->instr);
}
static inline void
nir_build_indirect_call(nir_builder *build, nir_function *func, nir_def *callee,
size_t count, nir_def **args)
{
assert(count == func->num_params && "parameter count must match");
assert(!func->impl && "cannot call directly defined functions indirectly");
nir_call_instr *call = nir_call_instr_create(build->shader, func);
for (unsigned i = 0; i < func->num_params; ++i) {
call->params[i] = nir_src_for_ssa(args[i]);
}
call->indirect_callee = nir_src_for_ssa(callee);
nir_builder_instr_insert(build, &call->instr);
}
static inline void
nir_discard(nir_builder *build)
{
@ -2276,6 +2292,12 @@ nir_build_string(nir_builder *build, const char *value);
nir_build_call(build, func, ARRAY_SIZE(args), args); \
} while (0)
#define nir_call_indirect(build, func, callee, ...) \
do { \
nir_def *_args[] = { __VA_ARGS__ }; \
nir_build_indirect_call(build, func, callee, ARRAY_SIZE(_args), _args); \
} while (0)
nir_def *
nir_compare_func(nir_builder *b, enum compare_func func,
nir_def *src0, nir_def *src1);

View file

@ -719,6 +719,7 @@ nir_function_clone(nir_shader *ns, const nir_function *fxn)
nfxn->should_inline = fxn->should_inline;
nfxn->dont_inline = fxn->dont_inline;
nfxn->is_subroutine = fxn->is_subroutine;
nfxn->driver_attributes = fxn->driver_attributes;
nfxn->is_tmp_globals_wrapper = fxn->is_tmp_globals_wrapper;
nfxn->num_subroutine_types = fxn->num_subroutine_types;
nfxn->subroutine_index = fxn->subroutine_index;

View file

@ -1120,8 +1120,9 @@ instr_is_loop_invariant(nir_instr *instr, struct divergence_state *state)
case nir_instr_type_deref:
case nir_instr_type_tex:
return nir_foreach_src(instr, src_invariant, state->loop);
case nir_instr_type_phi:
case nir_instr_type_call:
return false;
case nir_instr_type_phi:
case nir_instr_type_parallel_copy:
default:
unreachable("NIR divergence analysis: Unsupported instruction type.");
@ -1146,9 +1147,10 @@ update_instr_divergence(nir_instr *instr, struct divergence_state *state)
return visit_deref(state->shader, nir_instr_as_deref(instr), state);
case nir_instr_type_debug_info:
return false;
case nir_instr_type_call:
return false;
case nir_instr_type_jump:
case nir_instr_type_phi:
case nir_instr_type_call:
case nir_instr_type_parallel_copy:
default:
unreachable("NIR divergence analysis: Unsupported instruction type.");

View file

@ -204,7 +204,10 @@ static bool inline_functions_pass(nir_builder *b,
return false;
nir_call_instr *call = nir_instr_as_call(instr);
assert(call->callee->impl);
if (!call->callee->impl)
return false;
assert(!call->indirect_callee.ssa);
if (b->shader->options->driver_functions &&
b->shader->info.stage == MESA_SHADER_KERNEL) {

View file

@ -954,7 +954,8 @@ gather_func_info(nir_function_impl *func, nir_shader *shader,
nir_call_instr *call = nir_instr_as_call(instr);
nir_function_impl *impl = call->callee->impl;
assert(impl || !"nir_shader_gather_info only works with linked shaders");
if (!call->indirect_callee.ssa)
assert(impl || !"nir_shader_gather_info only works with linked shaders");
gather_func_info(impl, shader, visited_funcs, dead_ctx);
break;
}

View file

@ -107,6 +107,8 @@ nir_foreach_src(nir_instr *instr, nir_foreach_src_cb cb, void *state)
}
case nir_instr_type_call: {
nir_call_instr *call = nir_instr_as_call(instr);
if (call->indirect_callee.ssa && !_nir_visit_src(&call->indirect_callee, cb, state))
return false;
for (unsigned i = 0; i < call->num_params; i++) {
if (!_nir_visit_src(&call->params[i], cb, state))
return false;

View file

@ -1905,7 +1905,14 @@ print_call_instr(nir_call_instr *instr, print_state *state)
print_no_dest_padding(state);
bool indirect = instr->indirect_callee.ssa;
fprintf(fp, "call %s ", instr->callee->name);
if (indirect) {
fprintf(fp, "(indirect ");
print_src(&instr->indirect_callee, state, nir_type_invalid);
fprintf(fp, ") ");
}
for (unsigned i = 0; i < instr->num_params; i++) {
if (i != 0)

View file

@ -1983,6 +1983,8 @@ write_function(write_ctx *ctx, const nir_function *fxn)
blob_write_uint32(ctx->blob, fxn->workgroup_size[2]);
}
blob_write_uint32(ctx->blob, fxn->driver_attributes);
blob_write_uint32(ctx->blob, fxn->subroutine_index);
blob_write_uint32(ctx->blob, fxn->num_subroutine_types);
for (unsigned i = 0; i < fxn->num_subroutine_types; i++) {
@ -2011,6 +2013,7 @@ write_function(write_ctx *ctx, const nir_function *fxn)
encode_type_to_blob(ctx->blob, fxn->params[i].type);
blob_write_uint32(ctx->blob, encode_deref_modes(fxn->params[i].mode));
blob_write_uint32(ctx->blob, fxn->params[i].driver_attributes);
}
/* At first glance, it looks like we should write the function_impl here.
@ -2036,6 +2039,7 @@ read_function(read_ctx *ctx)
fxn->workgroup_size[2] = blob_read_uint32(ctx->blob);
}
fxn->driver_attributes = blob_read_uint32(ctx->blob);
fxn->subroutine_index = blob_read_uint32(ctx->blob);
fxn->num_subroutine_types = blob_read_uint32(ctx->blob);
for (unsigned i = 0; i < fxn->num_subroutine_types; i++) {
@ -2058,6 +2062,7 @@ read_function(read_ctx *ctx)
fxn->params[i].is_uniform = val & (1u << 17);
fxn->params[i].type = decode_type_from_blob(ctx->blob);
fxn->params[i].mode = decode_deref_modes(blob_read_uint32(ctx->blob));
fxn->params[i].driver_attributes = blob_read_uint32(ctx->blob);
}
fxn->is_entrypoint = flags & 0x1;

View file

@ -971,6 +971,11 @@ validate_call_instr(nir_call_instr *instr, validate_state *state)
{
validate_assert(state, instr->num_params == instr->callee->num_params);
if (instr->indirect_callee.ssa) {
validate_assert(state, !instr->callee->impl);
validate_src(&instr->indirect_callee, state);
}
for (unsigned i = 0; i < instr->num_params; i++) {
validate_sized_src(&instr->params[i], state,
instr->callee->params[i].bit_size,