diff --git a/src/gallium/frontends/clover/llvm/codegen/common.cpp b/src/gallium/frontends/clover/llvm/codegen/common.cpp index c9f303a9924..6bf15d8f79d 100644 --- a/src/gallium/frontends/clover/llvm/codegen/common.cpp +++ b/src/gallium/frontends/clover/llvm/codegen/common.cpp @@ -148,6 +148,47 @@ namespace { return detokenize(attributes, " "); } + // Parse the type which are pointers to CL vector types with no prefix. + // so e.g. char/uchar, short/ushort, int/uint, long/ulong + // half/float/double, followed by the vector length, followed by *. + // uint8 is 8x32-bit integer, short4 is 4x16-bit integer etc. + // Since this is a pointer only path, assert the * is on the end. + ::llvm::Type * + ptr_arg_to_llvm_type(const Module &mod, std::string type_name) { + int len = type_name.length(); + assert (type_name[len-1] == '*'); + ::llvm::Type *base_type = NULL; + if (type_name.find("void") != std::string::npos) + base_type = ::llvm::Type::getVoidTy(mod.getContext()); + else if (type_name.find("char") != std::string::npos) + base_type = ::llvm::Type::getInt8Ty(mod.getContext()); + else if (type_name.find("short") != std::string::npos) + base_type = ::llvm::Type::getInt16Ty(mod.getContext()); + else if (type_name.find("int") != std::string::npos) + base_type = ::llvm::Type::getInt32Ty(mod.getContext()); + else if (type_name.find("long") != std::string::npos) + base_type = ::llvm::Type::getInt64Ty(mod.getContext()); + else if (type_name.find("half") != std::string::npos) + base_type = ::llvm::Type::getHalfTy(mod.getContext()); + else if (type_name.find("float") != std::string::npos) + base_type = ::llvm::Type::getFloatTy(mod.getContext()); + else if (type_name.find("double") != std::string::npos) + base_type = ::llvm::Type::getDoubleTy(mod.getContext()); + + assert(base_type); + if (type_name.find("2") != std::string::npos) + base_type = ::llvm::FixedVectorType::get(base_type, 2); + else if (type_name.find("3") != std::string::npos) + base_type = ::llvm::FixedVectorType::get(base_type, 3); + else if (type_name.find("4") != std::string::npos) + base_type = ::llvm::FixedVectorType::get(base_type, 4); + else if (type_name.find("8") != std::string::npos) + base_type = ::llvm::FixedVectorType::get(base_type, 8); + else if (type_name.find("16") != std::string::npos) + base_type = ::llvm::FixedVectorType::get(base_type, 16); + return base_type; + } + std::vector make_kernel_args(const Module &mod, const std::string &kernel_name, const clang::CompilerInstance &c) { @@ -204,7 +245,7 @@ namespace { // Other types. const auto actual_type = isa< ::llvm::PointerType>(arg_type) && arg.hasByValAttr() ? - cast< ::llvm::PointerType>(arg_type)->getPointerElementType() : arg_type; + ptr_arg_to_llvm_type(mod, type_name) : arg_type; if (actual_type->isPointerTy()) { const unsigned address_space = @@ -214,11 +255,11 @@ namespace { const auto offset = static_cast(clang::LangAS::opencl_local); if (address_space == map[offset]) { - const auto pointee_type = cast< - ::llvm::PointerType>(actual_type)->getPointerElementType(); + const auto pointee_type = ptr_arg_to_llvm_type(mod, type_name); + args.emplace_back(binary::argument::local, arg_api_size, target_size, - dl.getABITypeAlignment(pointee_type), + (pointee_type->isVoidTy()) ? 8 : dl.getABITypeAlignment(pointee_type), binary::argument::zero_ext); } else { // XXX: Correctly handle constant address space. There is no