diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-02-18 12:34:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-02-18 12:34:38 +0100 |
commit | 0ea30263acc9f88544e1e7943adde6c58668dcce (patch) | |
tree | 3c91b7fb890f81aa774463b39baa09ce49e1face /src/database | |
parent | dc93523204ebe8562145997673f25f8e59f9d2f5 (diff) | |
parent | 7b2170818f11e0714c8b08aa1dd5b32bfef3f4b6 (diff) |
Merge pull request #137 from CNugteren/custom_parameters
API to override tuning parameters
Diffstat (limited to 'src/database')
-rw-r--r-- | src/database/database.cpp | 32 | ||||
-rw-r--r-- | src/database/database.hpp | 33 |
2 files changed, 50 insertions, 15 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp index c1cb9d56..02d0b139 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -63,7 +63,7 @@ const std::unordered_map<std::string, std::string> Database::kVendorNames{ // Constructor, computing device properties and populating the parameter-vector from the database. // This takes an optional overlay database in case of custom tuning or custom kernels. -Database::Database(const Device &device, const std::vector<std::string> &kernels, +Database::Database(const Device &device, const std::string &kernel_name, const Precision precision, const std::vector<const DatabaseEntry*> &overlay): parameters_(std::make_shared<Parameters>()) { @@ -79,20 +79,17 @@ Database::Database(const Device &device, const std::vector<std::string> &kernels } } - // Iterates over all kernels to include, and retrieves the parameters for each of them - for (auto &kernel: kernels) { - auto search_result = ParametersPtr{}; - - for (auto &db: { database, overlay}) { - search_result = Search(kernel, device_type, device_vendor, device_name, precision, db); - if (search_result) { - parameters_->insert(search_result->begin(), search_result->end()); - break; - } + // Searches potentially multiple databases + auto search_result = ParametersPtr{}; + for (auto &db: { overlay, database}) { + search_result = Search(kernel_name, device_type, device_vendor, device_name, precision, db); + if (search_result) { + parameters_->insert(search_result->begin(), search_result->end()); + break; } - - if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); } } + + if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); } } // ================================================================================================= @@ -106,6 +103,15 @@ std::string Database::GetDefines() const { return defines; } +// Retrieves the names of all the parameters +std::vector<std::string> Database::GetParameterNames() const { + auto parameter_names = std::vector<std::string>(); + for (auto ¶meter: *parameters_) { + parameter_names.push_back(parameter.first); + } + return parameter_names; +} + // ================================================================================================= // Searches a particular database for the right kernel and precision diff --git a/src/database/database.hpp b/src/database/database.hpp index 87c12293..b34e0d8a 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -75,15 +75,19 @@ class Database { Database() = default; // The constructor with a user-provided database overlay (potentially an empty vector) - explicit Database(const Device &device, const std::vector<std::string> &routines, + explicit Database(const Device &device, const std::string &kernel_name, const Precision precision, const std::vector<const DatabaseEntry*> &overlay); // Accessor of values by key - size_t operator[](const std::string key) const { return parameters_->find(key)->second; } + size_t operator[](const std::string &key) const { return parameters_->find(key)->second; } + bool exists(const std::string &key) const { return (parameters_->count(key) == 1); } // Obtain a list of OpenCL pre-processor defines based on the parameters std::string GetDefines() const; + // Retrieves the names of all the parameters + std::vector<std::string> GetParameterNames() const; + private: // Search method for a specified database, returning pointer (possibly a nullptr) ParametersPtr Search(const std::string &this_kernel, const std::string &this_type, @@ -96,6 +100,31 @@ class Database { }; // ================================================================================================= + +// Multiple databases together in a map +class Databases { + public: + + explicit Databases(const std::vector<std::string> &kernel_names): kernel_names_(kernel_names) { } + + // Database accessor + Database& operator()(const std::string &kernel_name) { return databases_[kernel_name]; } + + // Retrieves a parameter from the database + size_t operator[](const std::string &key) const { + for (const auto &kernel_name : kernel_names_) { + const auto &kernel_db = databases_.find(kernel_name)->second; + if (kernel_db.exists(key)) { return kernel_db[key]; } + } + throw RuntimeErrorCode(StatusCode::kDatabaseError); + } + + private: + const std::vector<std::string> kernel_names_; + std::unordered_map<std::string, Database> databases_; +}; + +// ================================================================================================= } // namespace clblast // CLBLAST_DATABASE_H_ |