summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/database/database.cpp45
-rw-r--r--src/database/database.hpp10
-rw-r--r--src/routine.cpp5
-rw-r--r--src/routine.hpp3
4 files changed, 39 insertions, 24 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp
index 6ec93731..ea1557b9 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -44,7 +44,7 @@ const std::vector<Database::DatabaseEntry> Database::database = {
// Constructor, computing device properties and populating the parameter-vector from the database
Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
- const Precision precision):
+ const Precision precision, const std::vector<DatabaseEntry> &overlay):
parameters_{} {
// Finds information of the current device
@@ -53,10 +53,23 @@ Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
auto device_vendor = device.Vendor();
auto device_name = device.Name();
+ // Set the short vendor name
+ for (auto &combination : kVendorNames) {
+ if (device_vendor == combination.first) {
+ device_vendor = combination.second;
+ }
+ }
+
// Iterates over all kernels to include, and retrieves the parameters for each of them
for (auto &kernel: kernels) {
- auto search_result = Search(kernel, device_type, device_vendor, device_name, precision);
- parameters_.insert(search_result.begin(), search_result.end());
+ auto search_result = ParametersPtr{};
+
+ for (auto db: { &overlay, &database }) {
+ search_result = Search(kernel, device_type, device_vendor, device_name, precision, *db);
+ if (search_result) { parameters_.insert(search_result->begin(), search_result->end()); break; }
+ }
+
+ if (!search_result) { throw std::runtime_error("Database error, could not find a suitable entry"); }
}
}
@@ -74,27 +87,21 @@ std::string Database::GetDefines() const {
// =================================================================================================
// Searches the database for the right kernel and precision
-Database::Parameters Database::Search(const std::string &this_kernel,
- const std::string &this_type,
- const std::string &this_vendor,
- const std::string &this_device,
- const Precision this_precision) const {
- // Set the short vendor name
- auto this_short_vendor = this_vendor;
- for (auto &combination : kVendorNames) {
- if (this_vendor == combination.first) {
- this_short_vendor = combination.second;
- }
- }
+Database::ParametersPtr Database::Search(const std::string &this_kernel,
+ const std::string &this_type,
+ const std::string &this_vendor,
+ const std::string &this_device,
+ const Precision this_precision,
+ const std::vector<DatabaseEntry> &this_database) const {
// Selects the right kernel
- for (auto &db: database) {
+ for (auto &db: this_database) {
if (db.kernel == this_kernel && db.precision == this_precision) {
// Searches for the right vendor and device type, or selects the default if unavailable. This
// assumes that the default vendor / device type is last in the database.
for (auto &vendor: db.vendors) {
- if ((vendor.name == this_short_vendor || vendor.name == kDeviceVendorAll) &&
+ if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) &&
(vendor.type == this_type || vendor.type == kDeviceTypeAll)) {
// Searches for the right device. If the current device is unavailable, selects the vendor
@@ -104,7 +111,7 @@ Database::Parameters Database::Search(const std::string &this_kernel,
if (device.name == this_device || device.name == "default") {
// Sets the parameters accordingly
- return device.parameters;
+ return &device.parameters;
}
}
}
@@ -113,7 +120,7 @@ Database::Parameters Database::Search(const std::string &this_kernel,
}
// If we reached this point, something is wrong
- throw std::runtime_error("Database error, could not find a suitable entry");
+ return nullptr;
}
// =================================================================================================
diff --git a/src/database/database.hpp b/src/database/database.hpp
index 0987cbed..5a61fad9 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -32,6 +32,7 @@ class Database {
// Type alias for the database parameters
using Parameters = std::unordered_map<std::string,size_t>;
+ using ParametersPtr = const Parameters*;
// Structures for content inside the database
struct DatabaseDevice {
@@ -78,9 +79,9 @@ class Database {
static const DatabaseEntry PadtransposeHalf, PadtransposeSingle, PadtransposeDouble, PadtransposeComplexSingle, PadtransposeComplexDouble;
static const std::vector<DatabaseEntry> database;
- // The constructor
+ // The constructor with a user-provided database overlay
explicit Database(const Queue &queue, const std::vector<std::string> &routines,
- const Precision precision);
+ const Precision precision, const std::vector<DatabaseEntry> &overlay);
// Accessor of values by key
size_t operator[](const std::string key) const { return parameters_.find(key)->second; }
@@ -93,6 +94,11 @@ class Database {
const std::string &this_vendor, const std::string &this_device,
const Precision this_precision) const;
+ // Alternate search method in a specified database, returning pointer (possibly NULL)
+ ParametersPtr Search(const std::string &this_kernel, const std::string &this_type,
+ const std::string &this_vendor, const std::string &this_device,
+ const Precision this_precision, const std::vector<DatabaseEntry> &db) const;
+
// Found parameters suitable for this device/kernel
Parameters parameters_;
};
diff --git a/src/routine.cpp b/src/routine.cpp
index 3c3343da..189ae190 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -22,7 +22,8 @@ namespace clblast {
// Constructor: not much here, because no status codes can be returned
Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
- const std::vector<std::string> &routines, const Precision precision):
+ const std::vector<std::string> &routines, const Precision precision,
+ const std::vector<Database::DatabaseEntry> &userDatabase):
precision_(precision),
routine_name_(name),
queue_(queue),
@@ -30,7 +31,7 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
device_name_(device_.Name()),
- db_(queue_, routines, precision_) {
+ db_(queue_, routines, precision_, userDatabase) {
}
// =================================================================================================
diff --git a/src/routine.hpp b/src/routine.hpp
index 54b5779f..21506e7b 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -34,7 +34,8 @@ class Routine {
// Base class constructor
explicit Routine(Queue &queue, EventPointer event, const std::string &name,
- const std::vector<std::string> &routines, const Precision precision);
+ const std::vector<std::string> &routines, const Precision precision,
+ const std::vector<Database::DatabaseEntry> &userDatabase = {});
// Set-up phase of the kernel
StatusCode SetUp();