summaryrefslogtreecommitdiff
path: root/src/database/database.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/database.cpp')
-rw-r--r--src/database/database.cpp55
1 files changed, 33 insertions, 22 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp
index 6ec93731..47f1da16 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -42,9 +42,10 @@ const std::vector<Database::DatabaseEntry> Database::database = {
// =================================================================================================
-// Constructor, computing device properties and populating the parameter-vector from the database
+// Constructor, computing device properties and populating the parameter-vector from the database.
+// This takes an optional overlay database in case of custom tuning or custom kernels.
Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
- const Precision precision):
+ const Precision precision, const std::vector<DatabaseEntry> &overlay):
parameters_{} {
// Finds information of the current device
@@ -53,10 +54,26 @@ Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
auto device_vendor = device.Vendor();
auto device_name = device.Name();
+ // Set the short vendor name
+ for (auto &combination : kVendorNames) {
+ if (device_vendor == combination.first) {
+ device_vendor = combination.second;
+ }
+ }
+
// Iterates over all kernels to include, and retrieves the parameters for each of them
for (auto &kernel: kernels) {
- auto search_result = Search(kernel, device_type, device_vendor, device_name, precision);
- parameters_.insert(search_result.begin(), search_result.end());
+ auto search_result = ParametersPtr{};
+
+ for (auto db: { &overlay, &database }) {
+ search_result = Search(kernel, device_type, device_vendor, device_name, precision, *db);
+ if (search_result) {
+ parameters_.insert(search_result->begin(), search_result->end());
+ break;
+ }
+ }
+
+ if (!search_result) { throw std::runtime_error("Database error, could not find a suitable entry"); }
}
}
@@ -73,28 +90,22 @@ std::string Database::GetDefines() const {
// =================================================================================================
-// Searches the database for the right kernel and precision
-Database::Parameters Database::Search(const std::string &this_kernel,
- const std::string &this_type,
- const std::string &this_vendor,
- const std::string &this_device,
- const Precision this_precision) const {
- // Set the short vendor name
- auto this_short_vendor = this_vendor;
- for (auto &combination : kVendorNames) {
- if (this_vendor == combination.first) {
- this_short_vendor = combination.second;
- }
- }
+// Searches a particular database for the right kernel and precision
+Database::ParametersPtr Database::Search(const std::string &this_kernel,
+ const std::string &this_type,
+ const std::string &this_vendor,
+ const std::string &this_device,
+ const Precision this_precision,
+ const std::vector<DatabaseEntry> &this_database) const {
// Selects the right kernel
- for (auto &db: database) {
+ for (auto &db: this_database) {
if (db.kernel == this_kernel && db.precision == this_precision) {
// Searches for the right vendor and device type, or selects the default if unavailable. This
// assumes that the default vendor / device type is last in the database.
for (auto &vendor: db.vendors) {
- if ((vendor.name == this_short_vendor || vendor.name == kDeviceVendorAll) &&
+ if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) &&
(vendor.type == this_type || vendor.type == kDeviceTypeAll)) {
// Searches for the right device. If the current device is unavailable, selects the vendor
@@ -104,7 +115,7 @@ Database::Parameters Database::Search(const std::string &this_kernel,
if (device.name == this_device || device.name == "default") {
// Sets the parameters accordingly
- return device.parameters;
+ return &device.parameters;
}
}
}
@@ -112,8 +123,8 @@ Database::Parameters Database::Search(const std::string &this_kernel,
}
}
- // If we reached this point, something is wrong
- throw std::runtime_error("Database error, could not find a suitable entry");
+ // If we reached this point, the entry was not found in this database
+ return nullptr;
}
// =================================================================================================