diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-12 19:56:21 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-12 19:56:21 +0200 |
commit | f2ba75890c522b4fe1762bfeac3e08667cf9588a (patch) | |
tree | 82e22cb72fbfb135570ce3bf3234bd1f60c760c1 /src | |
parent | 1c72d225c53c123ed810cf3f56f5c92603f7f791 (diff) |
Initial changes in preparation for half-precision fp16 support
Diffstat (limited to 'src')
-rw-r--r-- | src/kernels/common.opencl | 17 | ||||
-rw-r--r-- | src/routine.cc | 1 | ||||
-rw-r--r-- | src/utilities.cc | 20 |
3 files changed, 35 insertions, 3 deletions
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl index d401744d..349f9e4f 100644 --- a/src/kernels/common.opencl +++ b/src/kernels/common.opencl @@ -19,7 +19,7 @@ R"( // Parameters set by the tuner or by the database. Here they are given a basic default value in case // this file is used outside of the CLBlast library. #ifndef PRECISION - #define PRECISION 32 // Data-types: single or double precision, complex or regular + #define PRECISION 32 // Data-types: half, single or double precision, complex or regular #endif // ================================================================================================= @@ -31,8 +31,19 @@ R"( #endif #endif +// Half-precision +#if PRECISION == 16 + typedef half real; + typedef half2 real2; + typedef half4 real4; + typedef half8 real8; + typedef half16 real16; + #define ZERO 0.0 + #define ONE 1.0 + #define SMALLEST -1.0e37 + // Single-precision -#if PRECISION == 32 +#elif PRECISION == 32 typedef float real; typedef float2 real2; typedef float4 real4; @@ -68,7 +79,7 @@ R"( #define ONE 1.0f #define SMALLEST -1.0e37f -// Complex Double-precision +// Complex double-precision #elif PRECISION == 6464 typedef struct cdouble {double x; double y;} real; typedef struct cdouble2 {real x; real y;} real2; diff --git a/src/routine.cc b/src/routine.cc index e0cc9a90..5f9b1c89 100644 --- a/src/routine.cc +++ b/src/routine.cc @@ -397,6 +397,7 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev // ================================================================================================= // Compiles the templated class +template class Routine<half>; template class Routine<float>; template class Routine<double>; template class Routine<float2>; diff --git a/src/utilities.cc b/src/utilities.cc index 68a4f02a..5325c3e8 100644 --- a/src/utilities.cc +++ b/src/utilities.cc @@ -29,6 +29,7 @@ std::string ToString(T value) { } template std::string ToString<int>(int value); template std::string ToString<size_t>(size_t value); +template std::string ToString<half>(half value); template std::string ToString<float>(float value); template std::string ToString<double>(double value); @@ -105,6 +106,9 @@ template <typename T> T ConvertArgument(const char* value) { return static_cast<T>(std::stoi(value)); } +template <> half ConvertArgument(const char* value) { + return static_cast<half>(std::stod(value)); +} template <> float ConvertArgument(const char* value) { return static_cast<float>(std::stod(value)); } @@ -147,6 +151,7 @@ T GetArgument(const int argc, char *argv[], std::string &help, // Compiles the above function template int GetArgument<int>(const int, char **, std::string&, const std::string&, const int); template size_t GetArgument<size_t>(const int, char **, std::string&, const std::string&, const size_t); +template half GetArgument<half>(const int, char **, std::string&, const std::string&, const half); template float GetArgument<float>(const int, char **, std::string&, const std::string&, const float); template double GetArgument<double>(const int, char **, std::string&, const std::string&, const double); template float2 GetArgument<float2>(const int, char **, std::string&, const std::string&, const float2); @@ -227,6 +232,16 @@ void PopulateVector(std::vector<double2> &vector) { for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); } } +// Specialized versions of the above for half-precision +template <> +void PopulateVector(std::vector<half> &vector) { + auto lower_limit = static_cast<float>(kTestDataLowerLimit); + auto upper_limit = static_cast<float>(kTestDataUpperLimit); + std::mt19937 mt(GetRandomSeed()); + std::uniform_real_distribution<float> dist(lower_limit, upper_limit); + for (auto &element: vector) { element = static_cast<half>(dist(mt)); } +} + // ================================================================================================= // Returns a scalar with a default value @@ -234,6 +249,7 @@ template <typename T> T GetScalar() { return static_cast<T>(2.0); } +template half GetScalar<half>(); template float GetScalar<float>(); template double GetScalar<double>(); @@ -288,6 +304,10 @@ template <> bool PrecisionSupported<double2>(const Device &device) { auto extensions = device.Capabilities(); return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; } +template <> bool PrecisionSupported<half>(const Device &device) { + auto extensions = device.Capabilities(); + return (extensions.find(kKhronosHalfPrecision) == std::string::npos) ? false : true; +} // ================================================================================================= } // namespace clblast |