summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-05-13 20:49:34 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-05-13 20:49:34 +0200
commit120c31a30f933eea12d4dfffd4951fa22102ef5f (patch)
tree853aa6fae0522c9e92fce266c5fddb12a19dafd3 /src
parentf2ba75890c522b4fe1762bfeac3e08667cf9588a (diff)
Initial experimental version of the half-precision HAXPY routine
Diffstat (limited to 'src')
-rw-r--r--src/clblast.cc7
-rw-r--r--src/clblast_c.cc12
-rw-r--r--src/database.cc2
-rw-r--r--src/kernels/common.opencl18
-rw-r--r--src/kernels/level1/xaxpy.opencl6
-rw-r--r--src/routines/level1/xaxpy.cc6
-rw-r--r--src/routines/level3/xgemm.cc1
-rw-r--r--src/tuning/xaxpy.cc5
8 files changed, 46 insertions, 11 deletions
diff --git a/src/clblast.cc b/src/clblast.cc
index 8a9465c3..c18dc0a9 100644
--- a/src/clblast.cc
+++ b/src/clblast.cc
@@ -253,7 +253,7 @@ template StatusCode PUBLIC_API Copy<double2>(const size_t,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
-// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY
+// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY
template <typename T>
StatusCode Axpy(const size_t n,
const T alpha,
@@ -289,6 +289,11 @@ template StatusCode PUBLIC_API Axpy<double2>(const size_t,
const cl_mem, const size_t, const size_t,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Axpy<half>(const size_t,
+ const half,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
// Dot product of two vectors: SDOT/DDOT
template <typename T>
diff --git a/src/clblast_c.cc b/src/clblast_c.cc
index 1fc63de2..7642a1e4 100644
--- a/src/clblast_c.cc
+++ b/src/clblast_c.cc
@@ -312,6 +312,18 @@ StatusCode CLBlastZaxpy(const size_t n,
queue, event);
return static_cast<StatusCode>(status);
}
+StatusCode CLBlastHaxpy(const size_t n,
+ const cl_half alpha,
+ const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
+ cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
+ cl_command_queue* queue, cl_event* event) {
+ auto status = clblast::Axpy(n,
+ alpha,
+ x_buffer, x_offset, x_inc,
+ y_buffer, y_offset, y_inc,
+ queue, event);
+ return static_cast<StatusCode>(status);
+}
// DOT
StatusCode CLBlastSdot(const size_t n,
diff --git a/src/database.cc b/src/database.cc
index addd85d3..74d69c8b 100644
--- a/src/database.cc
+++ b/src/database.cc
@@ -29,7 +29,7 @@ namespace clblast {
// Initializes the database
const std::vector<Database::DatabaseEntry> Database::database = {
- XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble,
+ XaxpyHalf, XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble,
XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble,
XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble,
XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble,
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index 349f9e4f..df9ec35b 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -25,6 +25,11 @@ R"(
// =================================================================================================
// Enable support for double-precision
+#if PRECISION == 16
+ #pragma OPENCL EXTENSION cl_khr_fp16: enable
+#endif
+
+// Enable support for double-precision
#if PRECISION == 64 || PRECISION == 6464
#if __OPENCL_VERSION__ <= CL_VERSION_1_1
#pragma OPENCL EXTENSION cl_khr_fp64: enable
@@ -38,9 +43,9 @@ R"(
typedef half4 real4;
typedef half8 real8;
typedef half16 real16;
- #define ZERO 0.0
- #define ONE 1.0
- #define SMALLEST -1.0e37
+ #define ZERO 0.0h
+ #define ONE 1.0h
+ #define SMALLEST -1.0e37h
// Single-precision
#elif PRECISION == 32
@@ -95,6 +100,13 @@ R"(
#define SMALLEST -1.0e37
#endif
+// Kernel argument scalar
+#if PRECISION == 16
+ typedef float realarg;
+#else
+ typedef real realarg;
+#endif
+
// Single-element version of a complex number
#if PRECISION == 3232
typedef float singlereal;
diff --git a/src/kernels/level1/xaxpy.opencl b/src/kernels/level1/xaxpy.opencl
index 574beb43..58b7a196 100644
--- a/src/kernels/level1/xaxpy.opencl
+++ b/src/kernels/level1/xaxpy.opencl
@@ -23,9 +23,10 @@ R"(
// Full version of the kernel with offsets and strided accesses
__attribute__((reqd_work_group_size(WGS, 1, 1)))
-__kernel void Xaxpy(const int n, const real alpha,
+__kernel void Xaxpy(const int n, const realarg arg_alpha,
const __global real* restrict xgm, const int x_offset, const int x_inc,
__global real* ygm, const int y_offset, const int y_inc) {
+ const real alpha = (real)arg_alpha;
// Loops over the work that needs to be done (allows for an arbitrary number of threads)
#pragma unroll
@@ -40,9 +41,10 @@ __kernel void Xaxpy(const int n, const real alpha,
// Faster version of the kernel without offsets and strided accesses. Also assumes that 'n' is
// dividable by 'VW', 'WGS' and 'WPT'.
__attribute__((reqd_work_group_size(WGS, 1, 1)))
-__kernel void XaxpyFast(const int n, const real alpha,
+__kernel void XaxpyFast(const int n, const realarg arg_alpha,
const __global realV* restrict xgm,
__global realV* ygm) {
+ const real alpha = (real)arg_alpha;
#pragma unroll
for (int w=0; w<WPT; ++w) {
const int id = w*get_global_size(0) + get_global_id(0);
diff --git a/src/routines/level1/xaxpy.cc b/src/routines/level1/xaxpy.cc
index 96809a57..b7956bf2 100644
--- a/src/routines/level1/xaxpy.cc
+++ b/src/routines/level1/xaxpy.cc
@@ -20,6 +20,7 @@ namespace clblast {
// =================================================================================================
// Specific implementations to get the memory-type based on a template argument
+template <> const Precision Xaxpy<half>::precision_ = Precision::kHalf;
template <> const Precision Xaxpy<float>::precision_ = Precision::kSingle;
template <> const Precision Xaxpy<double>::precision_ = Precision::kDouble;
template <> const Precision Xaxpy<float2>::precision_ = Precision::kComplexSingle;
@@ -70,13 +71,13 @@ StatusCode Xaxpy<T>::DoAxpy(const size_t n, const T alpha,
// Sets the kernel arguments
if (use_fast_kernel) {
kernel.SetArgument(0, static_cast<int>(n));
- kernel.SetArgument(1, alpha);
+ kernel.SetArgument(1, static_cast<typename RealArg<T>::Type>(alpha));
kernel.SetArgument(2, x_buffer());
kernel.SetArgument(3, y_buffer());
}
else {
kernel.SetArgument(0, static_cast<int>(n));
- kernel.SetArgument(1, alpha);
+ kernel.SetArgument(1, static_cast<typename RealArg<T>::Type>(alpha));
kernel.SetArgument(2, x_buffer());
kernel.SetArgument(3, static_cast<int>(x_offset));
kernel.SetArgument(4, static_cast<int>(x_inc));
@@ -107,6 +108,7 @@ StatusCode Xaxpy<T>::DoAxpy(const size_t n, const T alpha,
// =================================================================================================
// Compiles the templated class
+template class Xaxpy<half>;
template class Xaxpy<float>;
template class Xaxpy<double>;
template class Xaxpy<float2>;
diff --git a/src/routines/level3/xgemm.cc b/src/routines/level3/xgemm.cc
index aa081e81..11116aae 100644
--- a/src/routines/level3/xgemm.cc
+++ b/src/routines/level3/xgemm.cc
@@ -20,6 +20,7 @@ namespace clblast {
// =================================================================================================
// Specific implementations to get the memory-type based on a template argument
+template <> const Precision Xgemm<half>::precision_ = Precision::kHalf;
template <> const Precision Xgemm<float>::precision_ = Precision::kSingle;
template <> const Precision Xgemm<double>::precision_ = Precision::kDouble;
template <> const Precision Xgemm<float2>::precision_ = Precision::kComplexSingle;
diff --git a/src/tuning/xaxpy.cc b/src/tuning/xaxpy.cc
index 31aa6a8e..7f62b811 100644
--- a/src/tuning/xaxpy.cc
+++ b/src/tuning/xaxpy.cc
@@ -90,7 +90,7 @@ class TuneXaxpy {
std::vector<T> &, std::vector<T> &, std::vector<T> &,
std::vector<T> &) {
tuner.AddArgumentScalar(static_cast<int>(args.n));
- tuner.AddArgumentScalar(args.alpha);
+ tuner.AddArgumentScalar(static_cast<typename RealArg<T>::Type>(args.alpha));
tuner.AddArgumentInput(x_vec);
tuner.AddArgumentOutput(y_vec);
}
@@ -106,13 +106,14 @@ class TuneXaxpy {
} // namespace clblast
// Shortcuts to the clblast namespace
+using half = clblast::half;
using float2 = clblast::float2;
using double2 = clblast::double2;
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
- case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode");
+ case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXaxpy<half>, half>(argc, argv); break;
case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXaxpy<float>, float>(argc, argv); break;
case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXaxpy<double>, double>(argc, argv); break;
case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXaxpy<float2>, float2>(argc, argv); break;