diff options
Diffstat (limited to 'src/database/database.cpp')
-rw-r--r-- | src/database/database.cpp | 71 |
1 files changed, 48 insertions, 23 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp index f1d1dc66..404be804 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -11,6 +11,8 @@ // // ================================================================================================= +#include <list> + #include "utilities/utilities.hpp" #include "database/database.hpp" @@ -28,31 +30,39 @@ #include "database/kernels/transpose.hpp" #include "database/kernels/padtranspose.hpp" #include "database/kernels/invert.hpp" +#include "database/apple_cpu_fallback.hpp" #include "database/kernel_selection.hpp" namespace clblast { // ================================================================================================= -// Initializes the database -const std::vector<const Database::DatabaseEntry*> Database::database = { - &database::XaxpyHalf, &database::XaxpySingle, &database::XaxpyDouble, &database::XaxpyComplexSingle, &database::XaxpyComplexDouble, - &database::XdotHalf, &database::XdotSingle, &database::XdotDouble, &database::XdotComplexSingle, &database::XdotComplexDouble, - &database::XgemvHalf, &database::XgemvSingle, &database::XgemvDouble, &database::XgemvComplexSingle, &database::XgemvComplexDouble, - &database::XgemvFastHalf, &database::XgemvFastSingle, &database::XgemvFastDouble, &database::XgemvFastComplexSingle, &database::XgemvFastComplexDouble, - &database::XgemvFastRotHalf, &database::XgemvFastRotSingle, &database::XgemvFastRotDouble, &database::XgemvFastRotComplexSingle, &database::XgemvFastRotComplexDouble, - &database::XgerHalf, &database::XgerSingle, &database::XgerDouble, &database::XgerComplexSingle, &database::XgerComplexDouble, - &database::XtrsvHalf, &database::XtrsvSingle, &database::XtrsvDouble, &database::XtrsvComplexSingle, &database::XtrsvComplexDouble, - &database::XgemmHalf, &database::XgemmSingle, &database::XgemmDouble, &database::XgemmComplexSingle, &database::XgemmComplexDouble, - &database::XgemmDirectHalf, &database::XgemmDirectSingle, &database::XgemmDirectDouble, &database::XgemmDirectComplexSingle, &database::XgemmDirectComplexDouble, - &database::CopyHalf, &database::CopySingle, &database::CopyDouble, &database::CopyComplexSingle, &database::CopyComplexDouble, - &database::PadHalf, &database::PadSingle, &database::PadDouble, &database::PadComplexSingle, &database::PadComplexDouble, - &database::TransposeHalf, &database::TransposeSingle, &database::TransposeDouble, &database::TransposeComplexSingle, &database::TransposeComplexDouble, - &database::PadtransposeHalf, &database::PadtransposeSingle, &database::PadtransposeDouble, &database::PadtransposeComplexSingle, &database::PadtransposeComplexDouble, - &database::InvertHalf, &database::InvertSingle, &database::InvertDouble, &database::InvertComplexSingle, &database::InvertComplexDouble, - &database::KernelSelectionHalf, &database::KernelSelectionSingle, &database::KernelSelectionDouble, &database::KernelSelectionComplexSingle, &database::KernelSelectionComplexDouble +// Initializes the databases +const std::vector<Database::DatabaseEntry> Database::database = std::vector<Database::DatabaseEntry>{ + database::XaxpyHalf, database::XaxpySingle, database::XaxpyDouble, database::XaxpyComplexSingle, database::XaxpyComplexDouble, + database::XdotHalf, database::XdotSingle, database::XdotDouble, database::XdotComplexSingle, database::XdotComplexDouble, + database::XgemvHalf, database::XgemvSingle, database::XgemvDouble, database::XgemvComplexSingle, database::XgemvComplexDouble, + database::XgemvFastHalf, database::XgemvFastSingle, database::XgemvFastDouble, database::XgemvFastComplexSingle, database::XgemvFastComplexDouble, + database::XgemvFastRotHalf, database::XgemvFastRotSingle, database::XgemvFastRotDouble, database::XgemvFastRotComplexSingle, database::XgemvFastRotComplexDouble, + database::XgerHalf, database::XgerSingle, database::XgerDouble, database::XgerComplexSingle, database::XgerComplexDouble, + database::XtrsvHalf, database::XtrsvSingle, database::XtrsvDouble, database::XtrsvComplexSingle, database::XtrsvComplexDouble, + database::XgemmHalf, database::XgemmSingle, database::XgemmDouble, database::XgemmComplexSingle, database::XgemmComplexDouble, + database::XgemmDirectHalf, database::XgemmDirectSingle, database::XgemmDirectDouble, database::XgemmDirectComplexSingle, database::XgemmDirectComplexDouble, + database::CopyHalf, database::CopySingle, database::CopyDouble, database::CopyComplexSingle, database::CopyComplexDouble, + database::PadHalf, database::PadSingle, database::PadDouble, database::PadComplexSingle, database::PadComplexDouble, + database::TransposeHalf, database::TransposeSingle, database::TransposeDouble, database::TransposeComplexSingle, database::TransposeComplexDouble, + database::PadtransposeHalf, database::PadtransposeSingle, database::PadtransposeDouble, database::PadtransposeComplexSingle, database::PadtransposeComplexDouble, + database::InvertHalf, database::InvertSingle, database::InvertDouble, database::InvertComplexSingle, database::InvertComplexDouble, + database::KernelSelectionHalf, database::KernelSelectionSingle, database::KernelSelectionDouble, database::KernelSelectionComplexSingle, database::KernelSelectionComplexDouble +}; +const std::vector<Database::DatabaseEntry> Database::apple_cpu_fallback = std::vector<Database::DatabaseEntry>{ + database::XaxpyApple, database::XdotApple, + database::XgemvApple, database::XgemvFastApple, database::XgemvFastRotApple, database::XgerApple, database::XtrsvApple, + database::XgemmApple, database::XgemmDirectApple, + database::CopyApple, database::PadApple, database::TransposeApple, database::PadtransposeApple, + database::InvertApple }; -// The OpenCL device vendors +// The default values const std::string Database::kDeviceVendorAll = "default"; // Alternative names for some OpenCL vendors @@ -68,7 +78,7 @@ const std::unordered_map<std::string, std::string> Database::kVendorNames{ // Constructor, computing device properties and populating the parameter-vector from the database. // This takes an optional overlay database in case of custom tuning or custom kernels. Database::Database(const Device &device, const std::string &kernel_name, - const Precision precision, const std::vector<const DatabaseEntry*> &overlay): + const Precision precision, const std::vector<DatabaseEntry> &overlay): parameters_(std::make_shared<Parameters>()) { // Finds information of the current device @@ -83,9 +93,23 @@ Database::Database(const Device &device, const std::string &kernel_name, } } + // Sets the databases to search through + auto databases = std::list<std::vector<DatabaseEntry>>{overlay, database}; + + // Special case: modifies the database if the device is a CPU with Apple OpenCL + #if defined(__APPLE__) || defined(__MACOSX) + if (device.Type() == "CPU") { + auto extensions = device.Capabilities(); + const auto is_apple = (extensions.find("cl_APPLE_SetMemObjectDestructor") == std::string::npos) ? false : true; + if (is_apple) { + databases.push_front(apple_cpu_fallback); + } + } + #endif + // Searches potentially multiple databases auto search_result = ParametersPtr{}; - for (auto &db: { overlay, database}) { + for (auto &db: databases) { search_result = Search(kernel_name, device_type, device_vendor, device_name, precision, db); if (search_result) { parameters_->insert(search_result->begin(), search_result->end()); @@ -124,15 +148,16 @@ Database::ParametersPtr Database::Search(const std::string &this_kernel, const std::string &this_vendor, const std::string &this_device, const Precision this_precision, - const std::vector<const DatabaseEntry*> &this_database) const { + const std::vector<DatabaseEntry> &this_database) const { // Selects the right kernel for (auto &db: this_database) { - if (db->kernel == this_kernel && db->precision == this_precision) { + if ((db.kernel == this_kernel) && + (db.precision == this_precision || db.precision == Precision::kAny)) { // Searches for the right vendor and device type, or selects the default if unavailable. This // assumes that the default vendor / device type is last in the database. - for (auto &vendor: db->vendors) { + for (auto &vendor: db.vendors) { if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) && (vendor.type == this_type || vendor.type == database::kDeviceTypeAll)) { |