summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2016-05-12 19:56:21 +0200
committer Cedric Nugteren <web@cedricnugteren.nl> 2016-05-12 19:56:21 +0200
commit f2ba75890c522b4fe1762bfeac3e08667cf9588a (patch)
tree 82e22cb72fbfb135570ce3bf3234bd1f60c760c1 /src
parent 1c72d225c53c123ed810cf3f56f5c92603f7f791 (diff)
Initial changes in preparation for half-precision fp16 support
Diffstat (limited to 'src')
-rw-r--r-- src/kernels/common.opencl 17
-rw-r--r-- src/routine.cc 1
-rw-r--r-- src/utilities.cc 20
3 files changed, 35 insertions, 3 deletions
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index d401744d..349f9e4f 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -19,7 +19,7 @@ R"(
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
// this file is used outside of the CLBlast library.
#ifndef PRECISION
- #define PRECISION 32 // Data-types: single or double precision, complex or regular
+ #define PRECISION 32 // Data-types: half, single or double precision, complex or regular
#endif
// =================================================================================================
@@ -31,8 +31,19 @@ R"(
#endif
#endif
+// Half-precision: aliases 'real' and its vector forms to the half types and defines the constants
+#if PRECISION == 16
+ typedef half real;
+ typedef half2 real2;
+ typedef half4 real4;
+ typedef half8 real8;
+ typedef half16 real16;
+ #define ZERO 0.0 // NOTE(review): unsuffixed literal, whereas constants elsewhere in this file carry an 'f' suffix (e.g. 1.0f) - confirm the intended literal type
+ #define ONE 1.0 // NOTE(review): same remark as for ZERO
+ #define SMALLEST -1.0e37 // NOTE(review): magnitude is far beyond the largest finite IEEE half value (65504) - confirm the result when used as 'real'
+
// Single-precision
-#if PRECISION == 32
+#elif PRECISION == 32
typedef float real;
typedef float2 real2;
typedef float4 real4;
@@ -68,7 +79,7 @@ R"(
#define ONE 1.0f
#define SMALLEST -1.0e37f
-// Complex Double-precision
+// Complex double-precision
#elif PRECISION == 6464
typedef struct cdouble {double x; double y;} real;
typedef struct cdouble2 {real x; real y;} real2;
diff --git a/src/routine.cc b/src/routine.cc
index e0cc9a90..5f9b1c89 100644
--- a/src/routine.cc
+++ b/src/routine.cc
@@ -397,6 +397,7 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev
// =================================================================================================
// Compiles the templated class
+template class Routine<half>;
template class Routine<float>;
template class Routine<double>;
template class Routine<float2>;
diff --git a/src/utilities.cc b/src/utilities.cc
index 68a4f02a..5325c3e8 100644
--- a/src/utilities.cc
+++ b/src/utilities.cc
@@ -29,6 +29,7 @@ std::string ToString(T value) {
}
template std::string ToString<int>(int value);
template std::string ToString<size_t>(size_t value);
+template std::string ToString<half>(half value);
template std::string ToString<float>(float value);
template std::string ToString<double>(double value);
@@ -105,6 +106,9 @@ template <typename T>
T ConvertArgument(const char* value) { // generic case, used for every T that has no specialization
return static_cast<T>(std::stoi(value)); // parses the text as an int via std::stoi (which throws on invalid input), then casts to T
}
+template <> half ConvertArgument(const char* value) { // specialization for half: parses floating-point text instead of an integer
+ return static_cast<half>(std::stod(value)); // parse as double, then cast; NOTE(review): the definition of 'half' is not visible here - if it is an integer-backed storage type this cast would not encode fp16, confirm
+}
template <> float ConvertArgument(const char* value) { // specialization for float: parses floating-point text instead of an integer
return static_cast<float>(std::stod(value)); // parse as double via std::stod, then cast to float
}
@@ -147,6 +151,7 @@ T GetArgument(const int argc, char *argv[], std::string &help,
// Compiles the above function
template int GetArgument<int>(const int, char **, std::string&, const std::string&, const int);
template size_t GetArgument<size_t>(const int, char **, std::string&, const std::string&, const size_t);
+template half GetArgument<half>(const int, char **, std::string&, const std::string&, const half);
template float GetArgument<float>(const int, char **, std::string&, const std::string&, const float);
template double GetArgument<double>(const int, char **, std::string&, const std::string&, const double);
template float2 GetArgument<float2>(const int, char **, std::string&, const std::string&, const float2);
@@ -227,6 +232,16 @@ void PopulateVector(std::vector<double2> &vector) {
for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); }
}
+// Specialized version of the above for half-precision: overwrites every element with a pseudo-random value
+template <>
+void PopulateVector(std::vector<half> &vector) {
+ auto lower_limit = static_cast<float>(kTestDataLowerLimit); // range bounds come from the kTestData* constants, converted to float
+ auto upper_limit = static_cast<float>(kTestDataUpperLimit);
+ std::mt19937 mt(GetRandomSeed()); // Mersenne Twister engine seeded with the value returned by GetRandomSeed()
+ std::uniform_real_distribution<float> dist(lower_limit, upper_limit); // yields floats in [lower_limit, upper_limit)
+ for (auto &element: vector) { element = static_cast<half>(dist(mt)); } // each float draw is cast to half before being stored
+}
+
// =================================================================================================
// Returns a scalar with a default value
@@ -234,6 +249,7 @@ template <typename T>
T GetScalar() { // generic definition; explicitly instantiated for the types listed right after it
return static_cast<T>(2.0); // the default scalar is the literal 2.0 converted to T
}
+template half GetScalar<half>();
template float GetScalar<float>();
template double GetScalar<double>();
@@ -288,6 +304,10 @@ template <> bool PrecisionSupported<double2>(const Device &device) {
auto extensions = device.Capabilities();
return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
}
+template <> bool PrecisionSupported<half>(const Device &device) { // specialization for half
+ auto extensions = device.Capabilities(); // value returned by the device's Capabilities() query
+ return (extensions.find(kKhronosHalfPrecision) == std::string::npos) ? false : true; // true only when kKhronosHalfPrecision occurs in that value
+}
// =================================================================================================
} // namespace clblast