summaryrefslogtreecommitdiff
path: root/src/database/database.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/database.cpp')
-rw-r--r--src/database/database.cpp59
1 files changed, 37 insertions, 22 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp
index 6ec93731..34c44a29 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -17,6 +17,8 @@
#include "database/kernels/xaxpy.hpp"
#include "database/kernels/xdot.hpp"
#include "database/kernels/xgemv.hpp"
+#include "database/kernels/xgemv_fast.hpp"
+#include "database/kernels/xgemv_fast_rot.hpp"
#include "database/kernels/xger.hpp"
#include "database/kernels/xgemm.hpp"
#include "database/kernels/copy.hpp"
@@ -32,6 +34,8 @@ const std::vector<Database::DatabaseEntry> Database::database = {
XaxpyHalf, XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble,
XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble,
XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble,
+ XgemvFastHalf, XgemvFastSingle, XgemvFastDouble, XgemvFastComplexSingle, XgemvFastComplexDouble,
+ XgemvFastRotHalf, XgemvFastRotSingle, XgemvFastRotDouble, XgemvFastRotComplexSingle, XgemvFastRotComplexDouble,
XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble,
XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble,
CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble,
@@ -42,9 +46,10 @@ const std::vector<Database::DatabaseEntry> Database::database = {
// =================================================================================================
-// Constructor, computing device properties and populating the parameter-vector from the database
+// Constructor, computing device properties and populating the parameter-vector from the database.
+// This takes an optional overlay database in case of custom tuning or custom kernels.
Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
- const Precision precision):
+ const Precision precision, const std::vector<DatabaseEntry> &overlay):
parameters_{} {
// Finds information of the current device
@@ -53,10 +58,26 @@ Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
auto device_vendor = device.Vendor();
auto device_name = device.Name();
+ // Set the short vendor name
+ for (auto &combination : kVendorNames) {
+ if (device_vendor == combination.first) {
+ device_vendor = combination.second;
+ }
+ }
+
// Iterates over all kernels to include, and retrieves the parameters for each of them
for (auto &kernel: kernels) {
- auto search_result = Search(kernel, device_type, device_vendor, device_name, precision);
- parameters_.insert(search_result.begin(), search_result.end());
+ auto search_result = ParametersPtr{};
+
+ for (auto db: { &overlay, &database }) {
+ search_result = Search(kernel, device_type, device_vendor, device_name, precision, *db);
+ if (search_result) {
+ parameters_.insert(search_result->begin(), search_result->end());
+ break;
+ }
+ }
+
+ if (!search_result) { throw std::runtime_error("Database error, could not find a suitable entry"); }
}
}
@@ -73,28 +94,22 @@ std::string Database::GetDefines() const {
// =================================================================================================
-// Searches the database for the right kernel and precision
-Database::Parameters Database::Search(const std::string &this_kernel,
- const std::string &this_type,
- const std::string &this_vendor,
- const std::string &this_device,
- const Precision this_precision) const {
- // Set the short vendor name
- auto this_short_vendor = this_vendor;
- for (auto &combination : kVendorNames) {
- if (this_vendor == combination.first) {
- this_short_vendor = combination.second;
- }
- }
+// Searches a particular database for the right kernel and precision
+Database::ParametersPtr Database::Search(const std::string &this_kernel,
+ const std::string &this_type,
+ const std::string &this_vendor,
+ const std::string &this_device,
+ const Precision this_precision,
+ const std::vector<DatabaseEntry> &this_database) const {
// Selects the right kernel
- for (auto &db: database) {
+ for (auto &db: this_database) {
if (db.kernel == this_kernel && db.precision == this_precision) {
// Searches for the right vendor and device type, or selects the default if unavailable. This
// assumes that the default vendor / device type is last in the database.
for (auto &vendor: db.vendors) {
- if ((vendor.name == this_short_vendor || vendor.name == kDeviceVendorAll) &&
+ if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) &&
(vendor.type == this_type || vendor.type == kDeviceTypeAll)) {
// Searches for the right device. If the current device is unavailable, selects the vendor
@@ -104,7 +119,7 @@ Database::Parameters Database::Search(const std::string &this_kernel,
if (device.name == this_device || device.name == "default") {
// Sets the parameters accordingly
- return device.parameters;
+ return &device.parameters;
}
}
}
@@ -112,8 +127,8 @@ Database::Parameters Database::Search(const std::string &this_kernel,
}
}
- // If we reached this point, something is wrong
- throw std::runtime_error("Database error, could not find a suitable entry");
+ // If we reached this point, the entry was not found in this database
+ return nullptr;
}
// =================================================================================================