diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 02db779113c..8aae60441c6 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -3761,6 +3761,17 @@ nir_shader_get_entrypoint(const nir_shader *shader) return func->impl; } +static inline nir_function * +nir_shader_get_function_for_name(const nir_shader *shader, const char *name) +{ + nir_foreach_function(func, shader) { + if (strcmp(func->name, name) == 0) + return func; + } + + return NULL; +} + void nir_remove_non_entrypoints(nir_shader *shader); nir_shader *nir_shader_create(void *mem_ctx, diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index 8262918912d..c41628fd475 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -139,24 +139,15 @@ static nir_function *mangle_and_find(struct vtn_builder *b, struct vtn_type **src_types) { char *mname; - nir_function *found = NULL; vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname); + /* try and find in current shader first. */ - nir_foreach_function(funcs, b->shader) { - if (!strcmp(funcs->name, mname)) { - found = funcs; - break; - } - } + nir_function *found = nir_shader_get_function_for_name(b->shader, mname); + /* if not found here find in clc shader and create a decl mirroring it */ if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) { - nir_foreach_function(funcs, b->options->clc_shader) { - if (!strcmp(funcs->name, mname)) { - found = funcs; - break; - } - } + found = nir_shader_get_function_for_name(b->options->clc_shader, mname); if (found) { nir_function *decl = nir_function_create(b->shader, mname); decl->num_params = found->num_params;