diff options
-rw-r--r-- | src/clblast.cpp | 28 | ||||
-rw-r--r-- | src/database/database_structure.hpp | 6 | ||||
-rw-r--r-- | src/routine.cpp | 1 | ||||
-rw-r--r-- | test/correctness/misc/override_parameters.cpp | 1 |
4 files changed, 16 insertions, 20 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index 3983e5fc..bb338503 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2499,26 +2499,20 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern auto in_cache = false; const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache); if (!in_cache) { return StatusCode::kInvalidOverrideKernel; } - for (const auto ¤t_param : current_database.GetParameterNames()) { - if (parameters.find(current_param) == parameters.end()) { - return StatusCode::kMissingOverrideParameter; - } - } - - // Clears the existing program & binary cache for routines with the target kernel - const auto routine_names = Routine::routines_by_kernel.at(kernel_name); - for (const auto &routine_name : routine_names) { - ProgramCache::Instance().RemoveBySubset<1, 2>(ProgramKey{nullptr, device, precision, routine_name}); - BinaryCache::Instance().Remove(BinaryKey{precision, routine_name, device_name}); + const auto current_parameter_names = current_database.GetParameterNames(); + if (current_parameter_names.size() != parameters.size()) { + return StatusCode::kMissingOverrideParameter; } - // Retrieves the names and values separately + // Retrieves the names and values separately and in the same order as the existing database auto parameter_values = database::Params{0}; - auto parameter_names = std::vector<std::string>(); auto i = size_t{0}; - for (const auto ¶meter : parameters) { - parameter_values[i] = parameter.second; - parameter_names.push_back(parameter.first); + for (const auto ¤t_param : current_parameter_names) { + if (parameters.find(current_param) == parameters.end()) { + return StatusCode::kMissingOverrideParameter; + } + const auto parameter_value = parameters.at(current_param); + parameter_values[i] = parameter_value; ++i; } @@ -2526,7 +2520,7 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values}; const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}}; const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}}; - const auto database_entry = database::DatabaseEntry{kernel_name, precision, parameter_names, {database_vendor}}; + const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}}; const auto database_entries = std::vector<database::DatabaseEntry>{database_entry}; const auto database = Database(device_cpp, kernel_name, precision, database_entries); diff --git a/src/database/database_structure.hpp b/src/database/database_structure.hpp index 9001b385..176fc556 100644 --- a/src/database/database_structure.hpp +++ b/src/database/database_structure.hpp @@ -17,7 +17,7 @@ #include <string> #include <array> #include <vector> -#include <unordered_map> +#include <map> namespace clblast { // A special namespace to hold all the global constant variables (including the database entries) @@ -29,8 +29,8 @@ namespace database { using Name = std::array<char, 51>; // name as stored in database (50 chars + string terminator) using Params = std::array<size_t, 14>; // parameters as stored in database -// Type alias after extracting from the database (map for improved code readability) -using Parameters = std::unordered_map<std::string, size_t>; // parameters after reading from DB +// Type alias after extracting from the database (sorted map for improved code readability) +using Parameters = std::map<std::string, size_t>; // parameters after reading from DB // The OpenCL device types const std::string kDeviceTypeCPU = "CPU"; diff --git a/src/routine.cpp b/src/routine.cpp index 4f0dd4d1..b25eec56 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -77,6 +77,7 @@ void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatab if (has_db) { continue; } // Builds the parameter database for this device and routine set and stores it in the cache + log_debug("Searching database for kernel '" + kernel_name + "'"); db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase); DatabaseCache::Instance().Store(DatabaseKey{ platform_, device_(), precision_, kernel_name }, Database{ db_(kernel_name) }); diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp index 535d9286..95ece98c 100644 --- a/test/correctness/misc/override_parameters.cpp +++ b/test/correctness/misc/override_parameters.cpp @@ -37,6 +37,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st const auto valid_settings = std::vector<std::unordered_map<std::string,size_t>>{ { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, { {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, + { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, }; const auto invalid_settings = std::vector<std::unordered_map<std::string,size_t>>{ { {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} }, |