summaryrefslogtreecommitdiff
path: root/src/clblast.cpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-09-24 15:44:14 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-09-24 15:44:14 +0200
commited980a1df1482e188e1d579b5025e7c86a5ec65c (patch)
tree91022782e01a471dc50e001ad72d2d24a47db035 /src/clblast.cpp
parent255f09843cb4a0a10e04a9581184750da0d36593 (diff)
Updated database override function to work with the new database storage format
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r--src/clblast.cpp28
1 files changed, 11 insertions, 17 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 3983e5fc..bb338503 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2499,26 +2499,20 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
auto in_cache = false;
const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache);
if (!in_cache) { return StatusCode::kInvalidOverrideKernel; }
- for (const auto &current_param : current_database.GetParameterNames()) {
- if (parameters.find(current_param) == parameters.end()) {
- return StatusCode::kMissingOverrideParameter;
- }
- }
-
- // Clears the existing program & binary cache for routines with the target kernel
- const auto routine_names = Routine::routines_by_kernel.at(kernel_name);
- for (const auto &routine_name : routine_names) {
- ProgramCache::Instance().RemoveBySubset<1, 2>(ProgramKey{nullptr, device, precision, routine_name});
- BinaryCache::Instance().Remove(BinaryKey{precision, routine_name, device_name});
+ const auto current_parameter_names = current_database.GetParameterNames();
+ if (current_parameter_names.size() != parameters.size()) {
+ return StatusCode::kMissingOverrideParameter;
}
- // Retrieves the names and values separately
+ // Retrieves the names and values separately and in the same order as the existing database
auto parameter_values = database::Params{0};
- auto parameter_names = std::vector<std::string>();
auto i = size_t{0};
- for (const auto &parameter : parameters) {
- parameter_values[i] = parameter.second;
- parameter_names.push_back(parameter.first);
+ for (const auto &current_param : current_parameter_names) {
+ if (parameters.find(current_param) == parameters.end()) {
+ return StatusCode::kMissingOverrideParameter;
+ }
+ const auto parameter_value = parameters.at(current_param);
+ parameter_values[i] = parameter_value;
++i;
}
@@ -2526,7 +2520,7 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values};
const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}};
const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}};
- const auto database_entry = database::DatabaseEntry{kernel_name, precision, parameter_names, {database_vendor}};
+ const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}};
const auto database_entries = std::vector<database::DatabaseEntry>{database_entry};
const auto database = Database(device_cpp, kernel_name, precision, database_entries);