summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/clblast.cpp28
-rw-r--r--src/database/database_structure.hpp6
-rw-r--r--src/routine.cpp1
-rw-r--r--test/correctness/misc/override_parameters.cpp1
4 files changed, 16 insertions, 20 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 3983e5fc..bb338503 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2499,26 +2499,20 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
auto in_cache = false;
const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache);
if (!in_cache) { return StatusCode::kInvalidOverrideKernel; }
- for (const auto &current_param : current_database.GetParameterNames()) {
- if (parameters.find(current_param) == parameters.end()) {
- return StatusCode::kMissingOverrideParameter;
- }
- }
-
- // Clears the existing program & binary cache for routines with the target kernel
- const auto routine_names = Routine::routines_by_kernel.at(kernel_name);
- for (const auto &routine_name : routine_names) {
- ProgramCache::Instance().RemoveBySubset<1, 2>(ProgramKey{nullptr, device, precision, routine_name});
- BinaryCache::Instance().Remove(BinaryKey{precision, routine_name, device_name});
+ const auto current_parameter_names = current_database.GetParameterNames();
+ if (current_parameter_names.size() != parameters.size()) {
+ return StatusCode::kMissingOverrideParameter;
}
- // Retrieves the names and values separately
+ // Retrieves the names and values separately and in the same order as the existing database
auto parameter_values = database::Params{0};
- auto parameter_names = std::vector<std::string>();
auto i = size_t{0};
- for (const auto &parameter : parameters) {
- parameter_values[i] = parameter.second;
- parameter_names.push_back(parameter.first);
+ for (const auto &current_param : current_parameter_names) {
+ if (parameters.find(current_param) == parameters.end()) {
+ return StatusCode::kMissingOverrideParameter;
+ }
+ const auto parameter_value = parameters.at(current_param);
+ parameter_values[i] = parameter_value;
++i;
}
@@ -2526,7 +2520,7 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values};
const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}};
const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}};
- const auto database_entry = database::DatabaseEntry{kernel_name, precision, parameter_names, {database_vendor}};
+ const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}};
const auto database_entries = std::vector<database::DatabaseEntry>{database_entry};
const auto database = Database(device_cpp, kernel_name, precision, database_entries);
diff --git a/src/database/database_structure.hpp b/src/database/database_structure.hpp
index 9001b385..176fc556 100644
--- a/src/database/database_structure.hpp
+++ b/src/database/database_structure.hpp
@@ -17,7 +17,7 @@
#include <string>
#include <array>
#include <vector>
-#include <unordered_map>
+#include <map>
namespace clblast {
// A special namespace to hold all the global constant variables (including the database entries)
@@ -29,8 +29,8 @@ namespace database {
using Name = std::array<char, 51>; // name as stored in database (50 chars + string terminator)
using Params = std::array<size_t, 14>; // parameters as stored in database
-// Type alias after extracting from the database (map for improved code readability)
-using Parameters = std::unordered_map<std::string, size_t>; // parameters after reading from DB
+// Type alias after extracting from the database (sorted map for improved code readability)
+using Parameters = std::map<std::string, size_t>; // parameters after reading from DB
// The OpenCL device types
const std::string kDeviceTypeCPU = "CPU";
diff --git a/src/routine.cpp b/src/routine.cpp
index 4f0dd4d1..b25eec56 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -77,6 +77,7 @@ void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatab
if (has_db) { continue; }
// Builds the parameter database for this device and routine set and stores it in the cache
+ log_debug("Searching database for kernel '" + kernel_name + "'");
db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
DatabaseCache::Instance().Store(DatabaseKey{ platform_, device_(), precision_, kernel_name },
Database{ db_(kernel_name) });
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp
index 535d9286..95ece98c 100644
--- a/test/correctness/misc/override_parameters.cpp
+++ b/test/correctness/misc/override_parameters.cpp
@@ -37,6 +37,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st
const auto valid_settings = std::vector<std::unordered_map<std::string,size_t>>{
{ {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
{ {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
+ { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
};
const auto invalid_settings = std::vector<std::unordered_map<std::string,size_t>>{
{ {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} },