summaryrefslogtreecommitdiff
path: root/src/clblast_c.cpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-02-26 14:51:45 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-02-26 14:51:45 +0100
commitea6790665d228e9ff9ba39983a60cd91611ee1fe (patch)
tree043ca277a867507f97f804cc4057fe50e548b9b1 /src/clblast_c.cpp
parenta145890aaac0087d36b414bd59c247ae4b70b3e5 (diff)
parent0643a29af51f9eb13e2b276d0a0e74590c699d3b (diff)
Merge branch 'development' into triangular_solvers
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r--src/clblast_c.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 07331e3a..6018bcfa 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -12,6 +12,7 @@
// =================================================================================================
#include <string>
+#include <unordered_map>
#include "utilities/utilities.hpp"
#include "clblast_c.h"
@@ -3463,3 +3464,23 @@ CLBlastStatusCode CLBlastFillCache(const cl_device_id device) {
}
// =================================================================================================
+
+// Overrides the tuning parameters for this device-precision-kernel combination
+CLBlastStatusCode PUBLIC_API CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name,
+ const CLBlastPrecision precision, const size_t num_parameters,
+ const char** parameters_names, const size_t* parameters_values) {
+ try {
+ const auto kernel_name_cpp = std::string(kernel_name);
+ const auto precision_cpp = static_cast<clblast::Precision>(precision);
+ auto parameters = std::unordered_map<std::string, size_t>();
+ for (auto i = size_t{0}; i < num_parameters; ++i) {
+ const auto parameter_name = std::string(parameters_names[i]);
+ const auto parameter_value = parameters_values[i];
+ parameters[parameter_name] = parameter_value;
+ }
+ const auto status = clblast::OverrideParameters(device, kernel_name_cpp, precision_cpp, parameters);
+ return static_cast<CLBlastStatusCode>(status);
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// =================================================================================================