diff options
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- | src/clblast.cpp | 28 |
1 files changed, 11 insertions, 17 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index 3983e5fc..bb338503 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2499,26 +2499,20 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern auto in_cache = false; const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache); if (!in_cache) { return StatusCode::kInvalidOverrideKernel; } - for (const auto ¤t_param : current_database.GetParameterNames()) { - if (parameters.find(current_param) == parameters.end()) { - return StatusCode::kMissingOverrideParameter; - } - } - - // Clears the existing program & binary cache for routines with the target kernel - const auto routine_names = Routine::routines_by_kernel.at(kernel_name); - for (const auto &routine_name : routine_names) { - ProgramCache::Instance().RemoveBySubset<1, 2>(ProgramKey{nullptr, device, precision, routine_name}); - BinaryCache::Instance().Remove(BinaryKey{precision, routine_name, device_name}); + const auto current_parameter_names = current_database.GetParameterNames(); + if (current_parameter_names.size() != parameters.size()) { + return StatusCode::kMissingOverrideParameter; } - // Retrieves the names and values separately + // Retrieves the names and values separately and in the same order as the existing database auto parameter_values = database::Params{0}; - auto parameter_names = std::vector<std::string>(); auto i = size_t{0}; - for (const auto ¶meter : parameters) { - parameter_values[i] = parameter.second; - parameter_names.push_back(parameter.first); + for (const auto ¤t_param : current_parameter_names) { + if (parameters.find(current_param) == parameters.end()) { + return StatusCode::kMissingOverrideParameter; + } + const auto parameter_value = parameters.at(current_param); + parameter_values[i] = parameter_value; ++i; } @@ -2526,7 +2520,7 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values}; const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}}; const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}}; - const auto database_entry = database::DatabaseEntry{kernel_name, precision, parameter_names, {database_vendor}}; + const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}}; const auto database_entries = std::vector<database::DatabaseEntry>{database_entry}; const auto database = Database(device_cpp, kernel_name, precision, database_entries); |