diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c
index 50c0e62519c..be23a142ddb 100644
--- a/src/compiler/nir/nir.c
+++ b/src/compiler/nir/nir.c
@@ -500,6 +500,7 @@ nir_function_create(nir_shader *shader, const char *name)
    func->is_preamble = false;
    func->dont_inline = false;
    func->should_inline = false;
+   func->driver_attributes = 0;
    func->is_subroutine = false;
    func->is_tmp_globals_wrapper = false;
    func->subroutine_index = 0;
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 6b7b8c90134..e258faa42f4 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -1871,6 +1871,10 @@ typedef struct {
    nir_instr instr;
 
    struct nir_function *callee;
+   /* If this function call is indirect, the function pointer to call.
+    * Otherwise, null initialized.
+    */
+   nir_src indirect_callee;
 
    unsigned num_params;
    nir_src params[];
@@ -3593,6 +3597,15 @@ typedef struct {
 
    nir_variable_mode mode;
 
+   /* Drivers may optionally stash flags here describing the parameter.
+    * For example, this might encode whether the driver expects the value
+    * to be uniform or divergent, if the driver handles divergent parameters
+    * differently from uniform ones.
+    *
+    * NIR will preserve this value but does not interpret it in any way.
+    */
+   uint32_t driver_attributes;
+
    /* The type of the function param */
    const struct glsl_type *type;
 
@@ -3618,6 +3631,14 @@ typedef struct nir_function {
     */
    nir_function_impl *impl;
 
+   /* Drivers may optionally stash flags here describing the function call.
+    * For example, this might encode the ABI used for the call if a driver
+    * supports multiple ABIs.
+    *
+    * NIR will preserve this value but does not interpret it in any way.
+    */
+   uint32_t driver_attributes;
+
    bool is_entrypoint;
    /* from SPIR-V linkage, only for libraries */
    bool is_exported;
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index eed9964d3ab..f07d03c6a1c 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -2243,6 +2243,22 @@ nir_build_call(nir_builder *build, nir_function *func, size_t count,
    nir_builder_instr_insert(build, &call->instr);
 }
 
+static inline void
+nir_build_indirect_call(nir_builder *build, nir_function *func, nir_def *callee,
+                        size_t count, nir_def **args)
+{
+   assert(count == func->num_params && "parameter count must match");
+   assert(!func->impl && "cannot call directly defined functions indirectly");
+   nir_call_instr *call = nir_call_instr_create(build->shader, func);
+
+   for (unsigned i = 0; i < func->num_params; ++i) {
+      call->params[i] = nir_src_for_ssa(args[i]);
+   }
+   call->indirect_callee = nir_src_for_ssa(callee);
+
+   nir_builder_instr_insert(build, &call->instr);
+}
+
 static inline void
 nir_discard(nir_builder *build)
 {
@@ -2276,6 +2292,12 @@ nir_build_string(nir_builder *build, const char *value);
       nir_build_call(build, func, ARRAY_SIZE(args), args); \
    } while (0)
 
+#define nir_call_indirect(build, func, callee, ...)                           \
+   do {                                                                       \
+      nir_def *_args[] = { __VA_ARGS__ };                                     \
+      nir_build_indirect_call(build, func, callee, ARRAY_SIZE(_args), _args); \
+   } while (0)
+
 nir_def *
 nir_compare_func(nir_builder *b, enum compare_func func,
                  nir_def *src0, nir_def *src1);
diff --git a/src/compiler/nir/nir_clone.c b/src/compiler/nir/nir_clone.c
index ecd985aacba..19b3b05a527 100644
--- a/src/compiler/nir/nir_clone.c
+++ b/src/compiler/nir/nir_clone.c
@@ -719,6 +719,7 @@ nir_function_clone(nir_shader *ns, const nir_function *fxn)
    nfxn->should_inline = fxn->should_inline;
    nfxn->dont_inline = fxn->dont_inline;
    nfxn->is_subroutine = fxn->is_subroutine;
+   nfxn->driver_attributes = fxn->driver_attributes;
    nfxn->is_tmp_globals_wrapper = fxn->is_tmp_globals_wrapper;
    nfxn->num_subroutine_types = fxn->num_subroutine_types;
    nfxn->subroutine_index = fxn->subroutine_index;
diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c
index 792fdd0ede2..4bc575d73b0 100644
--- a/src/compiler/nir/nir_divergence_analysis.c
+++ b/src/compiler/nir/nir_divergence_analysis.c
@@ -1120,8 +1120,9 @@ instr_is_loop_invariant(nir_instr *instr, struct divergence_state *state)
    case nir_instr_type_deref:
    case nir_instr_type_tex:
       return nir_foreach_src(instr, src_invariant, state->loop);
-   case nir_instr_type_phi:
    case nir_instr_type_call:
+      return false;
+   case nir_instr_type_phi:
    case nir_instr_type_parallel_copy:
    default:
       unreachable("NIR divergence analysis: Unsupported instruction type.");
@@ -1146,9 +1147,10 @@ update_instr_divergence(nir_instr *instr, struct divergence_state *state)
       return visit_deref(state->shader, nir_instr_as_deref(instr), state);
    case nir_instr_type_debug_info:
       return false;
+   case nir_instr_type_call:
+      return false;
    case nir_instr_type_jump:
    case nir_instr_type_phi:
-   case nir_instr_type_call:
    case nir_instr_type_parallel_copy:
    default:
       unreachable("NIR divergence analysis: Unsupported instruction type.");
diff --git a/src/compiler/nir/nir_functions.c b/src/compiler/nir/nir_functions.c
index 6a3c35fc57e..514e831920c 100644
--- a/src/compiler/nir/nir_functions.c
+++ b/src/compiler/nir/nir_functions.c
@@ -204,7 +204,10 @@ static bool inline_functions_pass(nir_builder *b,
       return false;
 
    nir_call_instr *call = nir_instr_as_call(instr);
-   assert(call->callee->impl);
+   if (!call->callee->impl)
+      return false;
+
+   assert(!call->indirect_callee.ssa);
 
    if (b->shader->options->driver_functions &&
        b->shader->info.stage == MESA_SHADER_KERNEL) {
diff --git a/src/compiler/nir/nir_gather_info.c b/src/compiler/nir/nir_gather_info.c
index 7139a780102..1340f6678bb 100644
--- a/src/compiler/nir/nir_gather_info.c
+++ b/src/compiler/nir/nir_gather_info.c
@@ -954,7 +954,10 @@ gather_func_info(nir_function_impl *func, nir_shader *shader,
             nir_call_instr *call = nir_instr_as_call(instr);
             nir_function_impl *impl = call->callee->impl;
 
-            assert(impl || !"nir_shader_gather_info only works with linked shaders");
-            gather_func_info(impl, shader, visited_funcs, dead_ctx);
+            if (!call->indirect_callee.ssa)
+               assert(impl || !"nir_shader_gather_info only works with linked shaders");
+            /* Indirect calls have no impl to recurse into. */
+            if (impl)
+               gather_func_info(impl, shader, visited_funcs, dead_ctx);
             break;
          }
diff --git a/src/compiler/nir/nir_inline_helpers.h b/src/compiler/nir/nir_inline_helpers.h
index 8f3994f5353..17f2581ccee 100644
--- a/src/compiler/nir/nir_inline_helpers.h
+++ b/src/compiler/nir/nir_inline_helpers.h
@@ -107,6 +107,8 @@ nir_foreach_src(nir_instr *instr, nir_foreach_src_cb cb, void *state)
    }
    case nir_instr_type_call: {
       nir_call_instr *call = nir_instr_as_call(instr);
+      if (call->indirect_callee.ssa && !_nir_visit_src(&call->indirect_callee, cb, state))
+         return false;
       for (unsigned i = 0; i < call->num_params; i++) {
          if (!_nir_visit_src(&call->params[i], cb, state))
             return false;
diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c
index 27b32d7d97c..20aa9acedb3 100644
--- a/src/compiler/nir/nir_print.c
+++ b/src/compiler/nir/nir_print.c
@@ -1905,7 +1905,14 @@ print_call_instr(nir_call_instr *instr, print_state *state)
 
    print_no_dest_padding(state);
 
+   bool indirect = instr->indirect_callee.ssa;
+
    fprintf(fp, "call %s ", instr->callee->name);
+   if (indirect) {
+      fprintf(fp, "(indirect ");
+      print_src(&instr->indirect_callee, state, nir_type_invalid);
+      fprintf(fp, ") ");
+   }
 
    for (unsigned i = 0; i < instr->num_params; i++) {
       if (i != 0)
diff --git a/src/compiler/nir/nir_serialize.c b/src/compiler/nir/nir_serialize.c
index 7bdd0da5082..aaed31f4afd 100644
--- a/src/compiler/nir/nir_serialize.c
+++ b/src/compiler/nir/nir_serialize.c
@@ -1983,6 +1983,8 @@ write_function(write_ctx *ctx, const nir_function *fxn)
       blob_write_uint32(ctx->blob, fxn->workgroup_size[2]);
    }
 
+   blob_write_uint32(ctx->blob, fxn->driver_attributes);
+
    blob_write_uint32(ctx->blob, fxn->subroutine_index);
    blob_write_uint32(ctx->blob, fxn->num_subroutine_types);
    for (unsigned i = 0; i < fxn->num_subroutine_types; i++) {
@@ -2011,6 +2013,7 @@ write_function(write_ctx *ctx, const nir_function *fxn)
 
       encode_type_to_blob(ctx->blob, fxn->params[i].type);
       blob_write_uint32(ctx->blob, encode_deref_modes(fxn->params[i].mode));
+      blob_write_uint32(ctx->blob, fxn->params[i].driver_attributes);
    }
 
    /* At first glance, it looks like we should write the function_impl here.
@@ -2036,6 +2039,7 @@ read_function(read_ctx *ctx)
       fxn->workgroup_size[2] = blob_read_uint32(ctx->blob);
    }
 
+   fxn->driver_attributes = blob_read_uint32(ctx->blob);
    fxn->subroutine_index = blob_read_uint32(ctx->blob);
    fxn->num_subroutine_types = blob_read_uint32(ctx->blob);
    for (unsigned i = 0; i < fxn->num_subroutine_types; i++) {
@@ -2058,6 +2062,7 @@ read_function(read_ctx *ctx)
       fxn->params[i].is_uniform = val & (1u << 17);
       fxn->params[i].type = decode_type_from_blob(ctx->blob);
       fxn->params[i].mode = decode_deref_modes(blob_read_uint32(ctx->blob));
+      fxn->params[i].driver_attributes = blob_read_uint32(ctx->blob);
    }
 
    fxn->is_entrypoint = flags & 0x1;
diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c
index 1a43af7b96b..98a1e492234 100644
--- a/src/compiler/nir/nir_validate.c
+++ b/src/compiler/nir/nir_validate.c
@@ -971,6 +971,11 @@ validate_call_instr(nir_call_instr *instr, validate_state *state)
 {
    validate_assert(state, instr->num_params == instr->callee->num_params);
 
+   if (instr->indirect_callee.ssa) {
+      validate_assert(state, !instr->callee->impl);
+      validate_src(&instr->indirect_callee, state);
+   }
+
    for (unsigned i = 0; i < instr->num_params; i++) {
       validate_sized_src(&instr->params[i], state,
                          instr->callee->params[i].bit_size,