summary | refs | log | tree | commit | diff
path: root/src/database/database.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/database.cpp')
-rw-r--r-- src/database/database.cpp 71
1 files changed, 48 insertions, 23 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp
index f1d1dc66..404be804 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -11,6 +11,8 @@
//
// =================================================================================================
+#include <list>
+
#include "utilities/utilities.hpp"
#include "database/database.hpp"
@@ -28,31 +30,39 @@
#include "database/kernels/transpose.hpp"
#include "database/kernels/padtranspose.hpp"
#include "database/kernels/invert.hpp"
+#include "database/apple_cpu_fallback.hpp"
#include "database/kernel_selection.hpp"
namespace clblast {
// =================================================================================================
-// Initializes the database
-const std::vector<const Database::DatabaseEntry*> Database::database = {
- &database::XaxpyHalf, &database::XaxpySingle, &database::XaxpyDouble, &database::XaxpyComplexSingle, &database::XaxpyComplexDouble,
- &database::XdotHalf, &database::XdotSingle, &database::XdotDouble, &database::XdotComplexSingle, &database::XdotComplexDouble,
- &database::XgemvHalf, &database::XgemvSingle, &database::XgemvDouble, &database::XgemvComplexSingle, &database::XgemvComplexDouble,
- &database::XgemvFastHalf, &database::XgemvFastSingle, &database::XgemvFastDouble, &database::XgemvFastComplexSingle, &database::XgemvFastComplexDouble,
- &database::XgemvFastRotHalf, &database::XgemvFastRotSingle, &database::XgemvFastRotDouble, &database::XgemvFastRotComplexSingle, &database::XgemvFastRotComplexDouble,
- &database::XgerHalf, &database::XgerSingle, &database::XgerDouble, &database::XgerComplexSingle, &database::XgerComplexDouble,
- &database::XtrsvHalf, &database::XtrsvSingle, &database::XtrsvDouble, &database::XtrsvComplexSingle, &database::XtrsvComplexDouble,
- &database::XgemmHalf, &database::XgemmSingle, &database::XgemmDouble, &database::XgemmComplexSingle, &database::XgemmComplexDouble,
- &database::XgemmDirectHalf, &database::XgemmDirectSingle, &database::XgemmDirectDouble, &database::XgemmDirectComplexSingle, &database::XgemmDirectComplexDouble,
- &database::CopyHalf, &database::CopySingle, &database::CopyDouble, &database::CopyComplexSingle, &database::CopyComplexDouble,
- &database::PadHalf, &database::PadSingle, &database::PadDouble, &database::PadComplexSingle, &database::PadComplexDouble,
- &database::TransposeHalf, &database::TransposeSingle, &database::TransposeDouble, &database::TransposeComplexSingle, &database::TransposeComplexDouble,
- &database::PadtransposeHalf, &database::PadtransposeSingle, &database::PadtransposeDouble, &database::PadtransposeComplexSingle, &database::PadtransposeComplexDouble,
- &database::InvertHalf, &database::InvertSingle, &database::InvertDouble, &database::InvertComplexSingle, &database::InvertComplexDouble,
- &database::KernelSelectionHalf, &database::KernelSelectionSingle, &database::KernelSelectionDouble, &database::KernelSelectionComplexSingle, &database::KernelSelectionComplexDouble
+// Initializes the databases
+const std::vector<Database::DatabaseEntry> Database::database = std::vector<Database::DatabaseEntry>{
+ database::XaxpyHalf, database::XaxpySingle, database::XaxpyDouble, database::XaxpyComplexSingle, database::XaxpyComplexDouble,
+ database::XdotHalf, database::XdotSingle, database::XdotDouble, database::XdotComplexSingle, database::XdotComplexDouble,
+ database::XgemvHalf, database::XgemvSingle, database::XgemvDouble, database::XgemvComplexSingle, database::XgemvComplexDouble,
+ database::XgemvFastHalf, database::XgemvFastSingle, database::XgemvFastDouble, database::XgemvFastComplexSingle, database::XgemvFastComplexDouble,
+ database::XgemvFastRotHalf, database::XgemvFastRotSingle, database::XgemvFastRotDouble, database::XgemvFastRotComplexSingle, database::XgemvFastRotComplexDouble,
+ database::XgerHalf, database::XgerSingle, database::XgerDouble, database::XgerComplexSingle, database::XgerComplexDouble,
+ database::XtrsvHalf, database::XtrsvSingle, database::XtrsvDouble, database::XtrsvComplexSingle, database::XtrsvComplexDouble,
+ database::XgemmHalf, database::XgemmSingle, database::XgemmDouble, database::XgemmComplexSingle, database::XgemmComplexDouble,
+ database::XgemmDirectHalf, database::XgemmDirectSingle, database::XgemmDirectDouble, database::XgemmDirectComplexSingle, database::XgemmDirectComplexDouble,
+ database::CopyHalf, database::CopySingle, database::CopyDouble, database::CopyComplexSingle, database::CopyComplexDouble,
+ database::PadHalf, database::PadSingle, database::PadDouble, database::PadComplexSingle, database::PadComplexDouble,
+ database::TransposeHalf, database::TransposeSingle, database::TransposeDouble, database::TransposeComplexSingle, database::TransposeComplexDouble,
+ database::PadtransposeHalf, database::PadtransposeSingle, database::PadtransposeDouble, database::PadtransposeComplexSingle, database::PadtransposeComplexDouble,
+ database::InvertHalf, database::InvertSingle, database::InvertDouble, database::InvertComplexSingle, database::InvertComplexDouble,
+ database::KernelSelectionHalf, database::KernelSelectionSingle, database::KernelSelectionDouble, database::KernelSelectionComplexSingle, database::KernelSelectionComplexDouble
+};
+const std::vector<Database::DatabaseEntry> Database::apple_cpu_fallback = std::vector<Database::DatabaseEntry>{
+ database::XaxpyApple, database::XdotApple,
+ database::XgemvApple, database::XgemvFastApple, database::XgemvFastRotApple, database::XgerApple, database::XtrsvApple,
+ database::XgemmApple, database::XgemmDirectApple,
+ database::CopyApple, database::PadApple, database::TransposeApple, database::PadtransposeApple,
+ database::InvertApple
};
-// The OpenCL device vendors
+// The default values
const std::string Database::kDeviceVendorAll = "default";
// Alternative names for some OpenCL vendors
@@ -68,7 +78,7 @@ const std::unordered_map<std::string, std::string> Database::kVendorNames{
// Constructor, computing device properties and populating the parameter-vector from the database.
// This takes an optional overlay database in case of custom tuning or custom kernels.
Database::Database(const Device &device, const std::string &kernel_name,
- const Precision precision, const std::vector<const DatabaseEntry*> &overlay):
+ const Precision precision, const std::vector<DatabaseEntry> &overlay):
parameters_(std::make_shared<Parameters>()) {
// Finds information of the current device
@@ -83,9 +93,23 @@ Database::Database(const Device &device, const std::string &kernel_name,
}
}
+ // Sets the databases to search through
+ auto databases = std::list<std::vector<DatabaseEntry>>{overlay, database};
+
+ // Special case: modifies the database if the device is a CPU with Apple OpenCL
+ #if defined(__APPLE__) || defined(__MACOSX)
+ if (device.Type() == "CPU") {
+ auto extensions = device.Capabilities();
+ const auto is_apple = (extensions.find("cl_APPLE_SetMemObjectDestructor") == std::string::npos) ? false : true;
+ if (is_apple) {
+ databases.push_front(apple_cpu_fallback);
+ }
+ }
+ #endif
+
// Searches potentially multiple databases
auto search_result = ParametersPtr{};
- for (auto &db: { overlay, database}) {
+ for (auto &db: databases) {
search_result = Search(kernel_name, device_type, device_vendor, device_name, precision, db);
if (search_result) {
parameters_->insert(search_result->begin(), search_result->end());
@@ -124,15 +148,16 @@ Database::ParametersPtr Database::Search(const std::string &this_kernel,
const std::string &this_vendor,
const std::string &this_device,
const Precision this_precision,
- const std::vector<const DatabaseEntry*> &this_database) const {
+ const std::vector<DatabaseEntry> &this_database) const {
// Selects the right kernel
for (auto &db: this_database) {
- if (db->kernel == this_kernel && db->precision == this_precision) {
+ if ((db.kernel == this_kernel) &&
+ (db.precision == this_precision || db.precision == Precision::kAny)) {
// Searches for the right vendor and device type, or selects the default if unavailable. This
// assumes that the default vendor / device type is last in the database.
- for (auto &vendor: db->vendors) {
+ for (auto &vendor: db.vendors) {
if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) &&
(vendor.type == this_type || vendor.type == database::kDeviceTypeAll)) {