summaryrefslogtreecommitdiff
path: root/src/clblast_c.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r--src/clblast_c.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index e4f2b3ed..79b6a640 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -12,6 +12,7 @@
// =================================================================================================
#include <string>
+#include <unordered_map>
#include "utilities/utilities.hpp"
#include "clblast_c.h"
@@ -3484,3 +3485,23 @@ CLBlastStatusCode CLBlastFillCache(const cl_device_id device) {
}
// =================================================================================================
+
+// Overrides the tuning parameters for this device-precision-kernel combination
+CLBlastStatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const char* kernel_name,
+ const CLBlastPrecision precision, const size_t num_parameters,
+ const char** parameters_names, const size_t* parameters_values) {
+ try {
+ const auto kernel_name_cpp = std::string(kernel_name);
+ const auto precision_cpp = static_cast<clblast::Precision>(precision);
+ auto parameters = std::unordered_map<std::string, size_t>();
+ for (auto i = size_t{0}; i < num_parameters; ++i) {
+ const auto parameter_name = std::string(parameters_names[i]);
+ const auto parameter_value = parameters_values[i];
+ parameters[parameter_name] = parameter_value;
+ }
+ const auto status = clblast::OverrideParameters(device, kernel_name_cpp, precision_cpp, parameters);
+ return static_cast<CLBlastStatusCode>(status);
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// =================================================================================================