summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-09-04 17:39:57 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-09-04 17:39:57 +0200
commit28462aa05068522733dcdc812795ec235c40f7e1 (patch)
tree0ca153975169c1afae2825cfacca4e130b613a91
parent297159d5b99f33f4e49cea238e66f1a1f05389a3 (diff)
Removed an assumption that the 'default' tuning parameters have to be stored last; this is no longer needed
-rw-r--r--CHANGELOG1
-rw-r--r--src/database/database.cpp67
-rw-r--r--src/database/database.hpp10
3 files changed, 52 insertions, 26 deletions
diff --git a/CHANGELOG b/CHANGELOG
index c3519778..d7d70b7a 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,6 +1,7 @@
Development (next version)
- The tuners can now use particle-swarm optimisation to search more efficiently (thanks to 'mcian')
+- Various minor fixes and enhancements
- Added non-BLAS routines:
* SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL (im2col transform as used to express convolution as GEMM)
diff --git a/src/database/database.cpp b/src/database/database.cpp
index fe543122..79c2ea03 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -154,31 +154,13 @@ Database::Parameters Database::Search(const std::string &this_kernel,
for (auto &db: this_database) {
if ((db.kernel == this_kernel) &&
(db.precision == this_precision || db.precision == Precision::kAny)) {
- const auto parameter_names = db.parameter_names;
-
- // Searches for the right vendor and device type, or selects the default if unavailable. This
- // assumes that the default vendor / device type is last in the database.
- for (auto &vendor: db.vendors) {
- if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) &&
- (vendor.type == this_type || vendor.type == database::kDeviceTypeAll)) {
-
- // Searches for the right device. If the current device is unavailable, selects the vendor
- // default parameters. This assumes the default is last in the database.
- for (auto &device: vendor.devices) {
-
- if (device.name == this_device || device.name == "default") {
-
- // Sets the parameters accordingly
- auto parameters = Parameters();
- if (parameter_names.size() != device.parameters.size()) { return Parameters(); } // ERROR
- for (auto i = size_t{0}; i < parameter_names.size(); ++i) {
- parameters[parameter_names[i]] = device.parameters[i];
- }
- return parameters;
- }
- }
- }
- }
+
+ // Searches for the right vendor and device type, or selects the default if unavailable
+ const auto parameters = SearchVendorAndType(this_vendor, this_type, this_device,
+ db.vendors, db.parameter_names);
+ if (parameters.size() != 0) { return parameters; }
+ return SearchVendorAndType(kDeviceVendorAll, database::kDeviceTypeAll, this_device,
+ db.vendors, db.parameter_names);
}
}
@@ -186,5 +168,40 @@ Database::Parameters Database::Search(const std::string &this_kernel,
return Parameters();
}
+Database::Parameters Database::SearchVendorAndType(const std::string &target_vendor,
+ const std::string &target_type,
+ const std::string &this_device,
+ const std::vector<DatabaseVendor> &vendors,
+ const std::vector<std::string> &parameter_names) const {
+ for (auto &vendor: vendors) {
+ if ((vendor.name == target_vendor) && (vendor.type == target_type)) {
+
+ // Searches the device; if unavailable, returns the vendor's default parameters
+ const auto parameters = SearchDevice(this_device, vendor.devices, parameter_names);
+ if (parameters.size() != 0) { return parameters; }
+ return SearchDevice("default", vendor.devices, parameter_names);
+ }
+ }
+ return Parameters();
+}
+
+Database::Parameters Database::SearchDevice(const std::string &target_device,
+ const std::vector<DatabaseDevice> &devices,
+ const std::vector<std::string> &parameter_names) const {
+ for (auto &device: devices) {
+ if (device.name == target_device) {
+
+ // Sets the parameters accordingly
+ auto parameters = Parameters();
+ if (parameter_names.size() != device.parameters.size()) { return Parameters(); } // ERROR
+ for (auto i = size_t{0}; i < parameter_names.size(); ++i) {
+ parameters[parameter_names[i]] = device.parameters[i];
+ }
+ return parameters;
+ }
+ }
+ return Parameters();
+}
+
// =================================================================================================
} // namespace clblast
diff --git a/src/database/database.hpp b/src/database/database.hpp
index 3f984439..b652164c 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -92,11 +92,19 @@ class Database {
std::vector<std::string> GetParameterNames() const;
private:
- // Search method for a specified database, returning pointer (possibly a nullptr)
+ // Search method functions, returning a set of parameters (possibly empty)
Parameters Search(const std::string &this_kernel, const std::string &this_type,
const std::string &this_vendor, const std::string &this_device,
const Precision this_precision,
const std::vector<DatabaseEntry> &db) const;
+ Parameters SearchDevice(const std::string &target_device,
+ const std::vector<DatabaseDevice> &devices,
+ const std::vector<std::string> &parameter_names) const;
+ Parameters SearchVendorAndType(const std::string &target_vendor,
+ const std::string &target_type,
+ const std::string &this_device,
+ const std::vector<DatabaseVendor> &vendors,
+ const std::vector<std::string> &parameter_names) const;
// Found parameters suitable for this device/kernel
std::shared_ptr<Parameters> parameters_;