summaryrefslogtreecommitdiff
path: root/src/database
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-02-12 12:02:39 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-02-12 12:02:39 +0100
commit345a5feb9a18641ceffd7ce5e0cb9387686cf32c (patch)
treecdae4f7fd42f4b23349867243434a1d648ed3e41 /src/database
parentfaa842b927ede6df1763607e3732151162875d73 (diff)
Split the database into several smaller cached per-kernel databases (in preparation of per-kernel database overrides)
Diffstat (limited to 'src/database')
-rw-r--r--src/database/database.cpp23
-rw-r--r--src/database/database.hpp30
2 files changed, 38 insertions, 15 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp
index c1cb9d56..8019d558 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -63,7 +63,7 @@ const std::unordered_map<std::string, std::string> Database::kVendorNames{
// Constructor, computing device properties and populating the parameter-vector from the database.
// This takes an optional overlay database in case of custom tuning or custom kernels.
-Database::Database(const Device &device, const std::vector<std::string> &kernels,
+Database::Database(const Device &device, const std::string &kernel_name,
const Precision precision, const std::vector<const DatabaseEntry*> &overlay):
parameters_(std::make_shared<Parameters>()) {
@@ -79,20 +79,17 @@ Database::Database(const Device &device, const std::vector<std::string> &kernels
}
}
- // Iterates over all kernels to include, and retrieves the parameters for each of them
- for (auto &kernel: kernels) {
- auto search_result = ParametersPtr{};
-
- for (auto &db: { database, overlay}) {
- search_result = Search(kernel, device_type, device_vendor, device_name, precision, db);
- if (search_result) {
- parameters_->insert(search_result->begin(), search_result->end());
- break;
- }
+ // Searches potentially multiple databases
+ auto search_result = ParametersPtr{};
+ for (auto &db: { overlay, database}) {
+ search_result = Search(kernel_name, device_type, device_vendor, device_name, precision, db);
+ if (search_result) {
+ parameters_->insert(search_result->begin(), search_result->end());
+ break;
}
-
- if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); }
}
+
+ if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); }
}
// =================================================================================================
diff --git a/src/database/database.hpp b/src/database/database.hpp
index 87c12293..b6760ec3 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -75,11 +75,12 @@ class Database {
Database() = default;
// The constructor with a user-provided database overlay (potentially an empty vector)
- explicit Database(const Device &device, const std::vector<std::string> &routines,
+ explicit Database(const Device &device, const std::string &kernel_name,
const Precision precision, const std::vector<const DatabaseEntry*> &overlay);
// Accessor of values by key
- size_t operator[](const std::string key) const { return parameters_->find(key)->second; }
+ size_t operator[](const std::string &key) const { return parameters_->find(key)->second; }
+ bool exists(const std::string &key) const { return (parameters_->count(key) == 1); }
// Obtain a list of OpenCL pre-processor defines based on the parameters
std::string GetDefines() const;
@@ -96,6 +97,31 @@ class Database {
};
// =================================================================================================
+
+// Multiple databases together in a map
+class Databases {
+ public:
+
+ explicit Databases(const std::vector<std::string> &kernel_names): kernel_names_(kernel_names) { }
+
+ // Database accessor
+ Database& operator()(const std::string &kernel_name) { return databases_[kernel_name]; }
+
+ // Retrieves a parameter from the database
+ size_t operator[](const std::string &key) const {
+ for (const auto &kernel_name : kernel_names_) {
+ const auto &kernel_db = databases_.find(kernel_name)->second;
+ if (kernel_db.exists(key)) { return kernel_db[key]; }
+ }
+ throw RuntimeErrorCode(StatusCode::kDatabaseError);
+ }
+
+ private:
+ const std::vector<std::string> kernel_names_;
+ std::unordered_map<std::string, Database> databases_;
+};
+
+// =================================================================================================
} // namespace clblast
// CLBLAST_DATABASE_H_