diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-02-16 21:12:50 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-02-16 21:12:50 +0100 |
commit | 08bfb75a9d72b6b373d8f18e8be83fe4ea31015b (patch) | |
tree | 93c7861c51c12b07e47a0fc266a004cfd782017a | |
parent | bdc57221bd0279bcdb4f024df54f08a2fe1bb8d4 (diff) |
Added input-sanity checks for the OverrideParameters function
-rw-r--r-- | include/clblast.h | 2 | ||||
-rwxr-xr-x | scripts/generator/generator.py | 4 | ||||
-rw-r--r-- | src/clblast.cpp | 10 | ||||
-rw-r--r-- | src/database/database.cpp | 9 | ||||
-rw-r--r-- | src/database/database.hpp | 3 |
5 files changed, 26 insertions, 2 deletions
diff --git a/include/clblast.h b/include/clblast.h index e7b53d65..1350cb10 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -97,6 +97,8 @@ enum class StatusCode { kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small // Custom additional status codes for CLBlast + kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel + kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel kInvalidLocalMemUsage = -2046, // Not enough local memory available on this device kNoHalfPrecision = -2045, // Half precision (16-bits) not supported by the device kNoDoublePrecision = -2044, // Double precision (64-bits) not supported by the device diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index aaf1b121..f43464b9 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -41,8 +41,8 @@ FILES = [ "/include/clblast_netlib_c.h", "/src/clblast_netlib_c.cpp", ] -HEADER_LINES = [119, 73, 118, 22, 29, 41, 65, 32] -FOOTER_LINES = [23, 128, 19, 18, 6, 6, 9, 2] +HEADER_LINES = [121, 73, 118, 22, 29, 41, 65, 32] +FOOTER_LINES = [23, 138, 19, 18, 6, 6, 9, 2] # Different possibilities for requirements ald_m = "The value of `a_ld` must be at least `m`." diff --git a/src/clblast.cpp b/src/clblast.cpp index 885b849e..871a4804 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2264,6 +2264,16 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern const auto device_cpp = Device(device); const auto device_name = device_cpp.Name(); + // Retrieves the current database values to verify whether the new ones are complete + auto in_cache = false; + const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision, device_name, kernel_name }, &in_cache); + if (!in_cache) { return StatusCode::kInvalidOverrideKernel; } + for (const auto ¤t_param : current_database.GetParameterNames()) { + if (parameters.find(current_param) == parameters.end()) { + return StatusCode::kMissingOverrideParameter; + } + } + // Clears the existing program & binary cache for routines with the target kernel const auto routine_names = Routine::routines_by_kernel.at(kernel_name); for (const auto &routine_name : routine_names) { diff --git a/src/database/database.cpp b/src/database/database.cpp index 8019d558..02d0b139 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -103,6 +103,15 @@ std::string Database::GetDefines() const { return defines; } +// Retrieves the names of all the parameters +std::vector<std::string> Database::GetParameterNames() const { + auto parameter_names = std::vector<std::string>(); + for (auto ¶meter: *parameters_) { + parameter_names.push_back(parameter.first); + } + return parameter_names; +} + // ================================================================================================= // Searches a particular database for the right kernel and precision diff --git a/src/database/database.hpp b/src/database/database.hpp index b6760ec3..b34e0d8a 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -85,6 +85,9 @@ class Database { // Obtain a list of OpenCL pre-processor defines based on the parameters std::string GetDefines() const; + // Retrieves the names of all the parameters + std::vector<std::string> GetParameterNames() const; + private: // Search method for a specified database, returning pointer (possibly a nullptr) ParametersPtr Search(const std::string &this_kernel, const std::string &this_type, |