summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/clblast.h3
-rw-r--r--include/clblast_c.h16
-rwxr-xr-xscripts/generator/generator.py4
-rw-r--r--src/clblast.cpp1
-rw-r--r--src/clblast_c.cpp21
5 files changed, 43 insertions, 2 deletions
diff --git a/include/clblast.h b/include/clblast.h
index 1350cb10..d9637d15 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -621,6 +621,9 @@ StatusCode PUBLIC_API FillCache(const cl_device_id device);
// =================================================================================================
+// Overrides tuning parameters for a specific device-precision-routine combination. The next time
+// (and all further times) the target routine is called it will re-compile and use the new
+// parameters.
StatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const std::string &kernel_name,
const Precision precision,
const std::unordered_map<std::string,size_t> &parameters);
diff --git a/include/clblast_c.h b/include/clblast_c.h
index 72f50d83..cd657f3b 100644
--- a/include/clblast_c.h
+++ b/include/clblast_c.h
@@ -96,6 +96,8 @@ typedef enum CLBlastStatusCode_ {
CLBlastInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
+ CLBlastInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
+ CLBlastMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
CLBlastInvalidLocalMemUsage = -2046, // Not enough local memory available on this device
CLBlastNoHalfPrecision = -2045, // Half precision (16-bits) not supported by the device
CLBlastNoDoublePrecision = -2044, // Double precision (64-bits) not supported by the device
@@ -117,6 +119,11 @@ typedef enum CLBlastDiagonal_ { CLBlastDiagonalNonUnit = 131,
CLBlastDiagonalUnit = 132 } CLBlastDiagonal;
typedef enum CLBlastSide_ { CLBlastSideLeft = 141, CLBlastSideRight = 142 } CLBlastSide;
+// Precision enum (values in bits)
+typedef enum CLBlastPrecision_ { CLBlastPrecisionHalf = 16, CLBlastPrecisionSingle = 32,
+ CLBlastPrecisionDouble = 64, CLBlastPrecisionComplexSingle = 3232,
+ CLBlastPrecisionComplexDouble = 6464 } CLBlastPrecision;
+
// =================================================================================================
// BLAS level-1 (vector-vector) routines
// =================================================================================================
@@ -1338,6 +1345,15 @@ CLBlastStatusCode PUBLIC_API CLBlastFillCache(const cl_device_id device);
// =================================================================================================
+// Overrides tuning parameters for a specific device-precision-routine combination. The next time
+// (and all further times) the target routine is called it will re-compile and use the new
+// parameters.
+CLBlastStatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const char* kernel_name,
+ const CLBlastPrecision precision, const size_t num_parameters,
+ const char** parameters_names, const size_t* parameters_values);
+
+// =================================================================================================
+
#ifdef __cplusplus
} // extern "C"
#endif
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index f43464b9..9bc48502 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -41,8 +41,8 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
-HEADER_LINES = [121, 73, 118, 22, 29, 41, 65, 32]
-FOOTER_LINES = [23, 138, 19, 18, 6, 6, 9, 2]
+HEADER_LINES = [121, 73, 125, 23, 29, 41, 65, 32]
+FOOTER_LINES = [26, 139, 28, 38, 6, 6, 9, 2]
# Different possibilities for requirements
ald_m = "The value of `a_ld` must be at least `m`."
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 871a4804..a8e4d084 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2255,6 +2255,7 @@ StatusCode FillCache(const cl_device_id device) {
// =================================================================================================
+// Overrides the tuning parameters for this device-precision-kernel combination
StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name,
const Precision precision,
const std::unordered_map<std::string,size_t> &parameters) {
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index e4f2b3ed..79b6a640 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -12,6 +12,7 @@
// =================================================================================================
#include <string>
+#include <unordered_map>
#include "utilities/utilities.hpp"
#include "clblast_c.h"
@@ -3484,3 +3485,23 @@ CLBlastStatusCode CLBlastFillCache(const cl_device_id device) {
}
// =================================================================================================
+
+// Overrides the tuning parameters for this device-precision-kernel combination
+CLBlastStatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const char* kernel_name,
+ const CLBlastPrecision precision, const size_t num_parameters,
+ const char** parameters_names, const size_t* parameters_values) {
+ try {
+ const auto kernel_name_cpp = std::string(kernel_name);
+ const auto precision_cpp = static_cast<clblast::Precision>(precision);
+ auto parameters = std::unordered_map<std::string, size_t>();
+ for (auto i = size_t{0}; i < num_parameters; ++i) {
+ const auto parameter_name = std::string(parameters_names[i]);
+ const auto parameter_value = parameters_values[i];
+ parameters[parameter_name] = parameter_value;
+ }
+ const auto status = clblast::OverrideParameters(device, kernel_name_cpp, precision_cpp, parameters);
+ return static_cast<CLBlastStatusCode>(status);
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// =================================================================================================