summary | refs | log | tree | commit | diff
path: root/include
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2016-05-13 20:49:34 +0200
committer Cedric Nugteren <web@cedricnugteren.nl> 2016-05-13 20:49:34 +0200
commit 120c31a30f933eea12d4dfffd4951fa22102ef5f (patch)
tree 853aa6fae0522c9e92fce266c5fddb12a19dafd3 /include
parent f2ba75890c522b4fe1762bfeac3e08667cf9588a (diff)
Initial experimental version of the half-precision HAXPY routine
Diffstat (limited to 'include')
-rw-r--r-- include/clblast.h 2
-rw-r--r-- include/clblast_c.h 7
-rw-r--r-- include/internal/database.h 2
-rw-r--r-- include/internal/database/xaxpy.h 18
-rw-r--r-- include/internal/utilities.h 4
5 files changed, 30 insertions, 3 deletions
diff --git a/include/clblast.h b/include/clblast.h
index 5df0f605..74ed6ab2 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -142,7 +142,7 @@ StatusCode Copy(const size_t n,
cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
cl_command_queue* queue, cl_event* event = nullptr);
-// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY
+// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY
template <typename T>
StatusCode Axpy(const size_t n,
const T alpha,
diff --git a/include/clblast_c.h b/include/clblast_c.h
index 8b2bf73c..e36eb68a 100644
--- a/include/clblast_c.h
+++ b/include/clblast_c.h
@@ -202,7 +202,7 @@ StatusCode PUBLIC_API CLBlastZcopy(const size_t n,
cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
cl_command_queue* queue, cl_event* event);
-// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY
+// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY
StatusCode PUBLIC_API CLBlastSaxpy(const size_t n,
const float alpha,
const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
@@ -223,6 +223,11 @@ StatusCode PUBLIC_API CLBlastZaxpy(const size_t n,
const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
cl_mem y_buffer, const size_t y_offset, const size_t y_inc,
cl_command_queue* queue, cl_event* event);
+StatusCode PUBLIC_API CLBlastHaxpy(const size_t n,  // HAXPY: the half-precision entry of the SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY group
+ const cl_half alpha,  // declared as cl_half here; presumably the "constant" in vector-times-constant plus vector
+ const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
+ cl_mem y_buffer, const size_t y_offset, const size_t y_inc,  // y_buffer is the only non-const buffer argument
+ cl_command_queue* queue, cl_event* event);
// Dot product of two vectors: SDOT/DDOT
StatusCode PUBLIC_API CLBlastSdot(const size_t n,
diff --git a/include/internal/database.h b/include/internal/database.h
index ca79fdad..5bf69358 100644
--- a/include/internal/database.h
+++ b/include/internal/database.h
@@ -67,7 +67,7 @@ class Database {
};
// The database consists of separate database entries, stored together in a vector
- static const DatabaseEntry XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble;
+ static const DatabaseEntry XaxpyHalf, XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble;
static const DatabaseEntry XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble;
static const DatabaseEntry XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble;
static const DatabaseEntry XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble;
diff --git a/include/internal/database/xaxpy.h b/include/internal/database/xaxpy.h
index 55be0bcb..6c5e478b 100644
--- a/include/internal/database/xaxpy.h
+++ b/include/internal/database/xaxpy.h
@@ -14,6 +14,24 @@
namespace clblast {
// =================================================================================================
+const Database::DatabaseEntry Database::XaxpyHalf = {  // database entry named "Xaxpy" for Precision::kHalf
+ "Xaxpy", Precision::kHalf, {  // entry name, precision, then one group per device type / vendor
+ { // Intel GPUs
+ kDeviceTypeGPU, "Intel", {  // device type, vendor name, then settings keyed by device name
+ { "Intel(R) HD Graphics Skylake ULT GT2", { {"VW",8}, {"WGS",512}, {"WPT",1} } },
+ { "default", { {"VW",8}, {"WGS",512}, {"WPT",1} } },  // presumably used for Intel GPUs not listed by name - confirm against the lookup code
+ }
+ },
+ { // Default
+ kDeviceTypeAll, "default", {  // presumably the fallback for any device type and vendor - confirm against the lookup code
+ { "default", { {"VW",8}, {"WGS",512}, {"WPT",1} } },  // NOTE(review): identical to the Skylake values above; all three rows currently share one setting
+ }
+ },
+ }
+};
+
+// =================================================================================================
+
const Database::DatabaseEntry Database::XaxpySingle = {
"Xaxpy", Precision::kSingle, {
{ // AMD GPUs
diff --git a/include/internal/utilities.h b/include/internal/utilities.h
index 46d9b8f1..854b3dfe 100644
--- a/include/internal/utilities.h
+++ b/include/internal/utilities.h
@@ -229,6 +229,10 @@ size_t GetBytes(const Precision precision);
template <typename T>
bool PrecisionSupported(const Device &device);
+// Maps a scalar type T to the type used when passing it as a kernel argument: T itself, except that half maps to float (half is not supported as an argument)
+template <typename T> struct RealArg { using Type = T; };
+template <> struct RealArg<half> { using Type = float; };
+
// =================================================================================================
} // namespace clblast