diff options
-rw-r--r-- | doc/clblast.md | 5 | ||||
-rw-r--r-- | include/clblast.h | 2 | ||||
-rw-r--r-- | include/clblast_c.h | 7 | ||||
-rw-r--r-- | include/internal/database.h | 2 | ||||
-rw-r--r-- | include/internal/database/xaxpy.h | 18 | ||||
-rw-r--r-- | include/internal/utilities.h | 4 | ||||
-rw-r--r-- | scripts/database/database.py | 11 | ||||
-rw-r--r-- | src/clblast.cc | 7 | ||||
-rw-r--r-- | src/clblast_c.cc | 12 | ||||
-rw-r--r-- | src/database.cc | 2 | ||||
-rw-r--r-- | src/kernels/common.opencl | 18 | ||||
-rw-r--r-- | src/kernels/level1/xaxpy.opencl | 6 | ||||
-rw-r--r-- | src/routines/level1/xaxpy.cc | 6 | ||||
-rw-r--r-- | src/routines/level3/xgemm.cc | 1 | ||||
-rw-r--r-- | src/tuning/xaxpy.cc | 5 |
15 files changed, 90 insertions, 16 deletions
diff --git a/doc/clblast.md b/doc/clblast.md index 9c9b9a6f..4b36789c 100644 --- a/doc/clblast.md +++ b/doc/clblast.md @@ -181,6 +181,11 @@ StatusCode CLBlastZaxpy(const size_t n, const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, cl_mem y_buffer, const size_t y_offset, const size_t y_inc, cl_command_queue* queue, cl_event* event) +StatusCode CLBlastHaxpy(const size_t n, + const cl_half alpha, + const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, + cl_mem y_buffer, const size_t y_offset, const size_t y_inc, + cl_command_queue* queue, cl_event* event) ``` Arguments to AXPY: diff --git a/include/clblast.h b/include/clblast.h index 5df0f605..74ed6ab2 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -142,7 +142,7 @@ StatusCode Copy(const size_t n, cl_mem y_buffer, const size_t y_offset, const size_t y_inc, cl_command_queue* queue, cl_event* event = nullptr); -// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY +// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY template <typename T> StatusCode Axpy(const size_t n, const T alpha, diff --git a/include/clblast_c.h b/include/clblast_c.h index 8b2bf73c..e36eb68a 100644 --- a/include/clblast_c.h +++ b/include/clblast_c.h @@ -202,7 +202,7 @@ StatusCode PUBLIC_API CLBlastZcopy(const size_t n, cl_mem y_buffer, const size_t y_offset, const size_t y_inc, cl_command_queue* queue, cl_event* event); -// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY +// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY StatusCode PUBLIC_API CLBlastSaxpy(const size_t n, const float alpha, const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, @@ -223,6 +223,11 @@ StatusCode PUBLIC_API CLBlastZaxpy(const size_t n, const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, cl_mem y_buffer, const size_t y_offset, const size_t y_inc, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHaxpy(const size_t n, + const cl_half alpha, + const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, + cl_mem y_buffer, const size_t y_offset, const size_t y_inc, + cl_command_queue* queue, cl_event* event); // Dot product of two vectors: SDOT/DDOT StatusCode PUBLIC_API CLBlastSdot(const size_t n, diff --git a/include/internal/database.h b/include/internal/database.h index ca79fdad..5bf69358 100644 --- a/include/internal/database.h +++ b/include/internal/database.h @@ -67,7 +67,7 @@ class Database { }; // The database consists of separate database entries, stored together in a vector - static const DatabaseEntry XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble; + static const DatabaseEntry XaxpyHalf, XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble; static const DatabaseEntry XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble; static const DatabaseEntry XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble; static const DatabaseEntry XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble; diff --git a/include/internal/database/xaxpy.h b/include/internal/database/xaxpy.h index 55be0bcb..6c5e478b 100644 --- a/include/internal/database/xaxpy.h +++ b/include/internal/database/xaxpy.h @@ -14,6 +14,24 @@ namespace clblast { // ================================================================================================= +const Database::DatabaseEntry Database::XaxpyHalf = { + "Xaxpy", Precision::kHalf, { + { // Intel GPUs + kDeviceTypeGPU, "Intel", { + { "Intel(R) HD Graphics Skylake ULT GT2", { {"VW",8}, {"WGS",512}, {"WPT",1} } }, + { "default", { {"VW",8}, {"WGS",512}, {"WPT",1} } }, + } + }, + { // Default + kDeviceTypeAll, "default", { + { "default", { {"VW",8}, {"WGS",512}, {"WPT",1} } }, + } + }, + } +}; + +// ================================================================================================= + const Database::DatabaseEntry Database::XaxpySingle = { "Xaxpy", Precision::kSingle, { { // AMD GPUs diff --git a/include/internal/utilities.h b/include/internal/utilities.h index 46d9b8f1..854b3dfe 100644 --- a/include/internal/utilities.h +++ b/include/internal/utilities.h @@ -229,6 +229,10 @@ size_t GetBytes(const Precision precision); template <typename T> bool PrecisionSupported(const Device &device); +// Converts a scalar to a scalar fit as a kernel argument (e.g. half is not supported) +template <typename T> struct RealArg { using Type = T; }; +template <> struct RealArg<half> { using Type = float; }; + // ================================================================================================= } // namespace clblast diff --git a/scripts/database/database.py b/scripts/database/database.py index d14e36cc..87e70fae 100644 --- a/scripts/database/database.py +++ b/scripts/database/database.py @@ -188,13 +188,20 @@ def GetFooter(): # The start of a new C++ precision entry def GetPrecision(family, precision): - precisionstring = "Single" - if precision == "64": + precisionstring = "" + if precision == "16": + precisionstring = "Half" + elif precision == "32": + precisionstring = "Single" + elif precision == "64": precisionstring = "Double" elif precision == "3232": precisionstring = "ComplexSingle" elif precision == "6464": precisionstring = "ComplexDouble" + else: + print("[ERROR] Unknown precision") + sys.exit() return("\n\nconst Database::DatabaseEntry Database::%s%s = {\n \"%s\", Precision::k%s, {\n" % (family.title(), precisionstring, family.title(), precisionstring)) diff --git a/src/clblast.cc b/src/clblast.cc index 8a9465c3..c18dc0a9 100644 --- a/src/clblast.cc +++ b/src/clblast.cc @@ -253,7 +253,7 @@ template StatusCode PUBLIC_API Copy<double2>(const size_t, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); -// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY +// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY template <typename T> StatusCode Axpy(const size_t n, const T alpha, @@ -289,6 +289,11 @@ template StatusCode PUBLIC_API Axpy<double2>(const size_t, const cl_mem, const size_t, const size_t, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Axpy<half>(const size_t, + const half, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // Dot product of two vectors: SDOT/DDOT template <typename T> diff --git a/src/clblast_c.cc b/src/clblast_c.cc index 1fc63de2..7642a1e4 100644 --- a/src/clblast_c.cc +++ b/src/clblast_c.cc @@ -312,6 +312,18 @@ StatusCode CLBlastZaxpy(const size_t n, queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHaxpy(const size_t n, + const cl_half alpha, + const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, + cl_mem y_buffer, const size_t y_offset, const size_t y_inc, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Axpy(n, + alpha, + x_buffer, x_offset, x_inc, + y_buffer, y_offset, y_inc, + queue, event); + return static_cast<StatusCode>(status); +} // DOT StatusCode CLBlastSdot(const size_t n, diff --git a/src/database.cc b/src/database.cc index addd85d3..74d69c8b 100644 --- a/src/database.cc +++ b/src/database.cc @@ -29,7 +29,7 @@ namespace clblast { // Initializes the database const std::vector<Database::DatabaseEntry> Database::database = { - XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble, + XaxpyHalf, XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble, diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl index 349f9e4f..df9ec35b 100644 --- a/src/kernels/common.opencl +++ b/src/kernels/common.opencl @@ -25,6 +25,11 @@ R"( // ================================================================================================= // Enable support for double-precision +#if PRECISION == 16 + #pragma OPENCL EXTENSION cl_khr_fp16: enable +#endif + +// Enable support for double-precision #if PRECISION == 64 || PRECISION == 6464 #if __OPENCL_VERSION__ <= CL_VERSION_1_1 #pragma OPENCL EXTENSION cl_khr_fp64: enable @@ -38,9 +43,9 @@ R"( typedef half4 real4; typedef half8 real8; typedef half16 real16; - #define ZERO 0.0 - #define ONE 1.0 - #define SMALLEST -1.0e37 + #define ZERO 0.0h + #define ONE 1.0h + #define SMALLEST -1.0e37h // Single-precision #elif PRECISION == 32 @@ -95,6 +100,13 @@ R"( #define SMALLEST -1.0e37 #endif +// Kernel argument scalar +#if PRECISION == 16 + typedef float realarg; +#else + typedef real realarg; +#endif + // Single-element version of a complex number #if PRECISION == 3232 typedef float singlereal; diff --git a/src/kernels/level1/xaxpy.opencl b/src/kernels/level1/xaxpy.opencl index 574beb43..58b7a196 100644 --- a/src/kernels/level1/xaxpy.opencl +++ b/src/kernels/level1/xaxpy.opencl @@ -23,9 +23,10 @@ R"( // Full version of the kernel with offsets and strided accesses __attribute__((reqd_work_group_size(WGS, 1, 1))) -__kernel void Xaxpy(const int n, const real alpha, +__kernel void Xaxpy(const int n, const realarg arg_alpha, const __global real* restrict xgm, const int x_offset, const int x_inc, __global real* ygm, const int y_offset, const int y_inc) { + const real alpha = (real)arg_alpha; // Loops over the work that needs to be done (allows for an arbitrary number of threads) #pragma unroll @@ -40,9 +41,10 @@ __kernel void Xaxpy(const int n, const real alpha, // Faster version of the kernel without offsets and strided accesses. Also assumes that 'n' is // dividable by 'VW', 'WGS' and 'WPT'. __attribute__((reqd_work_group_size(WGS, 1, 1))) -__kernel void XaxpyFast(const int n, const real alpha, +__kernel void XaxpyFast(const int n, const realarg arg_alpha, const __global realV* restrict xgm, __global realV* ygm) { + const real alpha = (real)arg_alpha; #pragma unroll for (int w=0; w<WPT; ++w) { const int id = w*get_global_size(0) + get_global_id(0); diff --git a/src/routines/level1/xaxpy.cc b/src/routines/level1/xaxpy.cc index 96809a57..b7956bf2 100644 --- a/src/routines/level1/xaxpy.cc +++ b/src/routines/level1/xaxpy.cc @@ -20,6 +20,7 @@ namespace clblast { // ================================================================================================= // Specific implementations to get the memory-type based on a template argument +template <> const Precision Xaxpy<half>::precision_ = Precision::kHalf; template <> const Precision Xaxpy<float>::precision_ = Precision::kSingle; template <> const Precision Xaxpy<double>::precision_ = Precision::kDouble; template <> const Precision Xaxpy<float2>::precision_ = Precision::kComplexSingle; @@ -70,13 +71,13 @@ StatusCode Xaxpy<T>::DoAxpy(const size_t n, const T alpha, // Sets the kernel arguments if (use_fast_kernel) { kernel.SetArgument(0, static_cast<int>(n)); - kernel.SetArgument(1, alpha); + kernel.SetArgument(1, static_cast<typename RealArg<T>::Type>(alpha)); kernel.SetArgument(2, x_buffer()); kernel.SetArgument(3, y_buffer()); } else { kernel.SetArgument(0, static_cast<int>(n)); - kernel.SetArgument(1, alpha); + kernel.SetArgument(1, static_cast<typename RealArg<T>::Type>(alpha)); kernel.SetArgument(2, x_buffer()); kernel.SetArgument(3, static_cast<int>(x_offset)); kernel.SetArgument(4, static_cast<int>(x_inc)); @@ -107,6 +108,7 @@ StatusCode Xaxpy<T>::DoAxpy(const size_t n, const T alpha, // ================================================================================================= // Compiles the templated class +template class Xaxpy<half>; template class Xaxpy<float>; template class Xaxpy<double>; template class Xaxpy<float2>; diff --git a/src/routines/level3/xgemm.cc b/src/routines/level3/xgemm.cc index aa081e81..11116aae 100644 --- a/src/routines/level3/xgemm.cc +++ b/src/routines/level3/xgemm.cc @@ -20,6 +20,7 @@ namespace clblast { // ================================================================================================= // Specific implementations to get the memory-type based on a template argument +template <> const Precision Xgemm<half>::precision_ = Precision::kHalf; template <> const Precision Xgemm<float>::precision_ = Precision::kSingle; template <> const Precision Xgemm<double>::precision_ = Precision::kDouble; template <> const Precision Xgemm<float2>::precision_ = Precision::kComplexSingle; diff --git a/src/tuning/xaxpy.cc b/src/tuning/xaxpy.cc index 31aa6a8e..7f62b811 100644 --- a/src/tuning/xaxpy.cc +++ b/src/tuning/xaxpy.cc @@ -90,7 +90,7 @@ class TuneXaxpy { std::vector<T> &, std::vector<T> &, std::vector<T> &, std::vector<T> &) { tuner.AddArgumentScalar(static_cast<int>(args.n)); - tuner.AddArgumentScalar(args.alpha); + tuner.AddArgumentScalar(static_cast<typename RealArg<T>::Type>(args.alpha)); tuner.AddArgumentInput(x_vec); tuner.AddArgumentOutput(y_vec); } @@ -106,13 +106,14 @@ class TuneXaxpy { } // namespace clblast // Shortcuts to the clblast namespace +using half = clblast::half; using float2 = clblast::float2; using double2 = clblast::double2; // Main function (not within the clblast namespace) int main(int argc, char *argv[]) { switch(clblast::GetPrecision(argc, argv)) { - case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); + case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXaxpy<half>, half>(argc, argv); break; case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXaxpy<float>, float>(argc, argv); break; case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXaxpy<double>, double>(argc, argv); break; case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXaxpy<float2>, float2>(argc, argv); break; |