diff options
-rw-r--r-- | include/clblast.h | 3 | ||||
-rw-r--r-- | include/clblast_c.h | 16 | ||||
-rwxr-xr-x | scripts/generator/generator.py | 4 | ||||
-rw-r--r-- | src/clblast.cpp | 1 | ||||
-rw-r--r-- | src/clblast_c.cpp | 21 |
5 files changed, 43 insertions, 2 deletions
diff --git a/include/clblast.h b/include/clblast.h index 1350cb10..d9637d15 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -621,6 +621,9 @@ StatusCode PUBLIC_API FillCache(const cl_device_id device); // ================================================================================================= +// Overrides tuning parameters for a specific device-precision-routine combination. The next time +// (and all further times) the target routine is called it will re-compile and use the new +// parameters. StatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const std::string &kernel_name, const Precision precision, const std::unordered_map<std::string,size_t> &parameters); diff --git a/include/clblast_c.h b/include/clblast_c.h index 72f50d83..cd657f3b 100644 --- a/include/clblast_c.h +++ b/include/clblast_c.h @@ -96,6 +96,8 @@ typedef enum CLBlastStatusCode_ { CLBlastInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small // Custom additional status codes for CLBlast + CLBlastInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel + CLBlastMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel CLBlastInvalidLocalMemUsage = -2046, // Not enough local memory available on this device CLBlastNoHalfPrecision = -2045, // Half precision (16-bits) not supported by the device CLBlastNoDoublePrecision = -2044, // Double precision (64-bits) not supported by the device @@ -117,6 +119,11 @@ typedef enum CLBlastDiagonal_ { CLBlastDiagonalNonUnit = 131, CLBlastDiagonalUnit = 132 } CLBlastDiagonal; typedef enum CLBlastSide_ { CLBlastSideLeft = 141, CLBlastSideRight = 142 } CLBlastSide; +// Precision enum (values in bits) +typedef enum CLBlastPrecision_ { CLBlastPrecisionHalf = 16, CLBlastPrecisionSingle = 32, + CLBlastPrecisionDouble = 64, CLBlastPrecisionComplexSingle = 3232, + CLBlastPrecisionComplexDouble = 6464 } CLBlastPrecision; + // 
================================================================================================= // BLAS level-1 (vector-vector) routines // ================================================================================================= @@ -1338,6 +1345,15 @@ CLBlastStatusCode PUBLIC_API CLBlastFillCache(const cl_device_id device); // ================================================================================================= +// Overrides tuning parameters for a specific device-precision-routine combination. The next time +// (and all further times) the target routine is called it will re-compile and use the new +// parameters. +CLBlastStatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const char* kernel_name, + const CLBlastPrecision precision, const size_t num_parameters, + const char** parameters_names, const size_t* parameters_values); + +// ================================================================================================= + #ifdef __cplusplus } // extern "C" #endif diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index f43464b9..9bc48502 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -41,8 +41,8 @@ FILES = [ "/include/clblast_netlib_c.h", "/src/clblast_netlib_c.cpp", ] -HEADER_LINES = [121, 73, 118, 22, 29, 41, 65, 32] -FOOTER_LINES = [23, 138, 19, 18, 6, 6, 9, 2] +HEADER_LINES = [121, 73, 125, 23, 29, 41, 65, 32] +FOOTER_LINES = [26, 139, 28, 38, 6, 6, 9, 2] # Different possibilities for requirements ald_m = "The value of `a_ld` must be at least `m`." 
diff --git a/src/clblast.cpp b/src/clblast.cpp index 871a4804..a8e4d084 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2255,6 +2255,7 @@ StatusCode FillCache(const cl_device_id device) { // ================================================================================================= +// Overrides the tuning parameters for this device-precision-kernel combination StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name, const Precision precision, const std::unordered_map<std::string,size_t> &parameters) { diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index e4f2b3ed..79b6a640 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -12,6 +12,7 @@ // ================================================================================================= #include <string> +#include <unordered_map> #include "utilities/utilities.hpp" #include "clblast_c.h" @@ -3484,3 +3485,23 @@ CLBlastStatusCode CLBlastFillCache(const cl_device_id device) { } // ================================================================================================= + +// Overrides the tuning parameters for this device-precision-kernel combination +CLBlastStatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const char* kernel_name, + const CLBlastPrecision precision, const size_t num_parameters, + const char** parameters_names, const size_t* parameters_values) { + try { + const auto kernel_name_cpp = std::string(kernel_name); + const auto precision_cpp = static_cast<clblast::Precision>(precision); + auto parameters = std::unordered_map<std::string, size_t>(); + for (auto i = size_t{0}; i < num_parameters; ++i) { + const auto parameter_name = std::string(parameters_names[i]); + const auto parameter_value = parameters_values[i]; + parameters[parameter_name] = parameter_value; + } + const auto status = clblast::OverrideParameters(device, kernel_name_cpp, precision_cpp, parameters); + return static_cast<CLBlastStatusCode>(status); + } catch (...) 
{ return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + +// ================================================================================================= |