diff --git a/src/microsoft/clc/clc_compiler.c b/src/microsoft/clc/clc_compiler.c index 3af1493efaa..d1127ca8b00 100644 --- a/src/microsoft/clc/clc_compiler.c +++ b/src/microsoft/clc/clc_compiler.c @@ -575,6 +575,34 @@ struct clc_libclc * return ctx; } +bool +clc_compile_c_to_spir(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spir) +{ + return clc_c_to_spir(args, logger, out_spir) >= 0; +} + +void +clc_free_spir(struct clc_binary *spir) +{ + clc_free_spir_binary(spir); +} + +bool +clc_compile_spir_to_spirv(const struct clc_binary *in_spir, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + if (clc_spir_to_spirv(in_spir, logger, out_spirv) < 0) + return false; + + if (debug_get_option_debug_clc() & CLC_DEBUG_DUMP_SPIRV) + clc_dump_spirv(out_spirv, stdout); + + return true; +} + void clc_free_spirv(struct clc_binary *spirv) { diff --git a/src/microsoft/clc/clc_compiler.h b/src/microsoft/clc/clc_compiler.h index 3f2cfea2eeb..0cf508deb48 100644 --- a/src/microsoft/clc/clc_compiler.h +++ b/src/microsoft/clc/clc_compiler.h @@ -200,6 +200,19 @@ void clc_libclc_serialize(struct clc_libclc *lib, void **serialized, size_t *siz void clc_libclc_free_serialized(void *serialized); struct clc_libclc *clc_libclc_deserialize(void *serialized, size_t size); +bool +clc_compile_c_to_spir(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spir); + +void +clc_free_spir(struct clc_binary *spir); + +bool +clc_compile_spir_to_spirv(const struct clc_binary *in_spir, + const struct clc_logger *logger, + struct clc_binary *out_spirv); + void clc_free_spirv(struct clc_binary *spirv); diff --git a/src/microsoft/clc/clc_helpers.cpp b/src/microsoft/clc/clc_helpers.cpp index 18d056ea0d2..62ae00282ae 100644 --- a/src/microsoft/clc/clc_helpers.cpp +++ b/src/microsoft/clc/clc_helpers.cpp @@ -31,6 +31,8 @@ #include #include #include +#include +#include #include #include #include @@ -576,10 +578,9 @@ clc_free_kernels_info(const struct clc_kernel_info *kernels, free((void *)kernels); } -int -clc_c_to_spirv(const struct clc_compile_args *args, - const struct clc_logger *logger, - struct clc_binary *out_spirv) +static std::pair, std::unique_ptr> +clc_compile_to_llvm_module(const struct clc_compile_args *args, + const struct clc_logger *logger) { LLVMInitializeAllTargets(); LLVMInitializeAllTargetInfos(); @@ -626,13 +627,13 @@ clc_c_to_spirv(const struct clc_compile_args *args, diag)) { log += "Couldn't create Clang invocation.\n"; clc_error(logger, log.c_str()); - return -1; + return {}; } if (diag.hasErrorOccurred()) { log += "Errors occurred during Clang invocation.\n"; clc_error(logger, log.c_str()); - return -1; + return {}; } // This is a workaround for a Clang bug which causes the number @@ -696,10 +697,19 @@ clc_c_to_spirv(const struct clc_compile_args *args, if (!c->ExecuteAction(act)) { log += "Error executing LLVM compilation action.\n"; clc_error(logger, log.c_str()); - return -1; + return {}; } - auto mod = act.takeModule(); + return { act.takeModule(), std::move(llvm_ctx) }; +} + +static int +llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod, + std::unique_ptr context, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + std::string log; std::ostringstream spv_stream; if (!::llvm::writeSpirv(mod.get(), spv_stream, log)) { log += "Translation from LLVM IR to SPIR-V failed.\n"; @@ -715,34 +725,54 @@ clc_c_to_spirv(const struct clc_compile_args *args, return 0; } -static const char * -spv_result_to_str(spv_result_t res) +int +clc_c_to_spir(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spir) { - switch (res) { - case SPV_SUCCESS: return "success"; - case SPV_UNSUPPORTED: return "unsupported"; - case SPV_END_OF_STREAM: return "end of stream"; - case SPV_WARNING: return "warning"; - case SPV_FAILED_MATCH: return "failed match"; - case SPV_REQUESTED_TERMINATION: return "requested termination"; - case SPV_ERROR_INTERNAL: return "internal error"; - case SPV_ERROR_OUT_OF_MEMORY: return "out of memory"; - case SPV_ERROR_INVALID_POINTER: return "invalid pointer"; - case SPV_ERROR_INVALID_BINARY: return "invalid binary"; - case SPV_ERROR_INVALID_TEXT: return "invalid text"; - case SPV_ERROR_INVALID_TABLE: return "invalid table"; - case SPV_ERROR_INVALID_VALUE: return "invalid value"; - case SPV_ERROR_INVALID_DIAGNOSTIC: return "invalid diagnostic"; - case SPV_ERROR_INVALID_LOOKUP: return "invalid lookup"; - case SPV_ERROR_INVALID_ID: return "invalid id"; - case SPV_ERROR_INVALID_CFG: return "invalid config"; - case SPV_ERROR_INVALID_LAYOUT: return "invalid layout"; - case SPV_ERROR_INVALID_CAPABILITY: return "invalid capability"; - case SPV_ERROR_INVALID_DATA: return "invalid data"; - case SPV_ERROR_MISSING_EXTENSION: return "missing extension"; - case SPV_ERROR_WRONG_VERSION: return "wrong version"; - default: return "unknown error"; - } + auto pair = clc_compile_to_llvm_module(args, logger); + if (!pair.first) + return -1; + + ::llvm::SmallVector buffer; + ::llvm::BitcodeWriter writer(buffer); + writer.writeModule(*pair.first); + + out_spir->size = buffer.size_in_bytes(); + out_spir->data = malloc(out_spir->size); + memcpy(out_spir->data, buffer.data(), out_spir->size); + + return 0; +} + +int +clc_c_to_spirv(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + auto pair = clc_compile_to_llvm_module(args, logger); + if (!pair.first) + return -1; + return llvm_mod_to_spirv(std::move(pair.first), std::move(pair.second), logger, out_spirv); +} + +int +clc_spir_to_spirv(const struct clc_binary *in_spir, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + LLVMInitializeAllTargets(); + LLVMInitializeAllTargetInfos(); + LLVMInitializeAllTargetMCs(); + LLVMInitializeAllAsmPrinters(); + + std::unique_ptr llvm_ctx{ new LLVMContext }; + ::llvm::StringRef spir_ref(static_cast(in_spir->data), in_spir->size); + auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, ""), *llvm_ctx); + if (!mod) + return -1; + + return llvm_mod_to_spirv(std::move(mod.get()), std::move(llvm_ctx), logger, out_spirv); } class SPIRVMessageConsumer { @@ -819,6 +849,12 @@ clc_dump_spirv(const struct clc_binary *spvbin, FILE *f) fwrite(out.c_str(), out.size(), 1, f); } +void +clc_free_spir_binary(struct clc_binary *spir) +{ + free(spir->data); +} + void clc_free_spirv_binary(struct clc_binary *spvbin) { diff --git a/src/microsoft/clc/clc_helpers.h b/src/microsoft/clc/clc_helpers.h index c85caac0b04..afb50f8e355 100644 --- a/src/microsoft/clc/clc_helpers.h +++ b/src/microsoft/clc/clc_helpers.h @@ -48,6 +48,16 @@ void clc_free_kernels_info(const struct clc_kernel_info *kernels, unsigned num_kernels); +int +clc_c_to_spir(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spir); + +int +clc_spir_to_spirv(const struct clc_binary *in_spir, + const struct clc_logger *logger, + struct clc_binary *out_spirv); + int clc_c_to_spirv(const struct clc_compile_args *args, const struct clc_logger *logger, @@ -61,6 +71,9 @@ clc_link_spirv_binaries(const struct clc_linker_args *args, void clc_dump_spirv(const struct clc_binary *spvbin, FILE *f); +void +clc_free_spir_binary(struct clc_binary *spir); + void clc_free_spirv_binary(struct clc_binary *spvbin); diff --git a/src/microsoft/clc/clon12compiler.def b/src/microsoft/clc/clon12compiler.def index 8a76cf87d83..f0c8a2268dd 100644 --- a/src/microsoft/clc/clon12compiler.def +++ b/src/microsoft/clc/clon12compiler.def @@ -4,6 +4,9 @@ EXPORTS clc_libclc_serialize clc_libclc_free_serialized clc_libclc_deserialize + clc_compile_c_to_spir + clc_free_spir + clc_compile_spir_to_spirv clc_free_spirv clc_compile_c_to_spirv clc_link_spirv