summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-02-16 21:12:50 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-02-16 21:12:50 +0100
commit08bfb75a9d72b6b373d8f18e8be83fe4ea31015b (patch)
tree93c7861c51c12b07e47a0fc266a004cfd782017a
parentbdc57221bd0279bcdb4f024df54f08a2fe1bb8d4 (diff)
Added input-sanity checks for the OverrideParameters function
-rw-r--r--include/clblast.h2
-rwxr-xr-xscripts/generator/generator.py4
-rw-r--r--src/clblast.cpp10
-rw-r--r--src/database/database.cpp9
-rw-r--r--src/database/database.hpp3
5 files changed, 26 insertions, 2 deletions
diff --git a/include/clblast.h b/include/clblast.h
index e7b53d65..1350cb10 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -97,6 +97,8 @@ enum class StatusCode {
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
+ kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
+ kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
kInvalidLocalMemUsage = -2046, // Not enough local memory available on this device
kNoHalfPrecision = -2045, // Half precision (16-bits) not supported by the device
kNoDoublePrecision = -2044, // Double precision (64-bits) not supported by the device
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index aaf1b121..f43464b9 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -41,8 +41,8 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
-HEADER_LINES = [119, 73, 118, 22, 29, 41, 65, 32]
-FOOTER_LINES = [23, 128, 19, 18, 6, 6, 9, 2]
+HEADER_LINES = [121, 73, 118, 22, 29, 41, 65, 32]
+FOOTER_LINES = [23, 138, 19, 18, 6, 6, 9, 2]
# Different possibilities for requirements
ald_m = "The value of `a_ld` must be at least `m`."
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 885b849e..871a4804 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2264,6 +2264,16 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
const auto device_cpp = Device(device);
const auto device_name = device_cpp.Name();
+ // Retrieves the current database values to verify whether the new ones are complete
+ auto in_cache = false;
+ const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision, device_name, kernel_name }, &in_cache);
+ if (!in_cache) { return StatusCode::kInvalidOverrideKernel; }
+ for (const auto &current_param : current_database.GetParameterNames()) {
+ if (parameters.find(current_param) == parameters.end()) {
+ return StatusCode::kMissingOverrideParameter;
+ }
+ }
+
// Clears the existing program & binary cache for routines with the target kernel
const auto routine_names = Routine::routines_by_kernel.at(kernel_name);
for (const auto &routine_name : routine_names) {
diff --git a/src/database/database.cpp b/src/database/database.cpp
index 8019d558..02d0b139 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -103,6 +103,15 @@ std::string Database::GetDefines() const {
return defines;
}
+// Retrieves the names of all the parameters
+std::vector<std::string> Database::GetParameterNames() const {
+ auto parameter_names = std::vector<std::string>();
+ for (auto &parameter: *parameters_) {
+ parameter_names.push_back(parameter.first);
+ }
+ return parameter_names;
+}
+
// =================================================================================================
// Searches a particular database for the right kernel and precision
diff --git a/src/database/database.hpp b/src/database/database.hpp
index b6760ec3..b34e0d8a 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -85,6 +85,9 @@ class Database {
// Obtain a list of OpenCL pre-processor defines based on the parameters
std::string GetDefines() const;
+ // Retrieves the names of all the parameters
+ std::vector<std::string> GetParameterNames() const;
+
private:
// Search method for a specified database, returning pointer (possibly a nullptr)
ParametersPtr Search(const std::string &this_kernel, const std::string &this_type,