From f2ba75890c522b4fe1762bfeac3e08667cf9588a Mon Sep 17 00:00:00 2001
From: Cedric Nugteren
Date: Thu, 12 May 2016 19:56:21 +0200
Subject: Initial changes in preparation for half-precision fp16 support

---
 src/kernels/common.opencl | 17 ++++++++++++++---
 src/routine.cc            |  1 +
 src/utilities.cc          | 20 ++++++++++++++++++++
 3 files changed, 35 insertions(+), 3 deletions(-)

(limited to 'src')

diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index d401744d..349f9e4f 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -19,7 +19,7 @@ R"(
 // Parameters set by the tuner or by the database. Here they are given a basic default value in case
 // this file is used outside of the CLBlast library.
 #ifndef PRECISION
-  #define PRECISION 32 // Data-types: single or double precision, complex or regular
+  #define PRECISION 32 // Data-types: half, single or double precision, complex or regular
 #endif
 
 // =================================================================================================
@@ -31,8 +31,19 @@ R"(
   #endif
 #endif
 
+// Half-precision
+#if PRECISION == 16
+  typedef half real;
+  typedef half2 real2;
+  typedef half4 real4;
+  typedef half8 real8;
+  typedef half16 real16;
+  #define ZERO 0.0
+  #define ONE 1.0
+  #define SMALLEST -1.0e37
+
 // Single-precision
-#if PRECISION == 32
+#elif PRECISION == 32
   typedef float real;
   typedef float2 real2;
   typedef float4 real4;
@@ -68,7 +79,7 @@ R"(
   #define ONE 1.0f
   #define SMALLEST -1.0e37f
 
-// Complex Double-precision
+// Complex double-precision
 #elif PRECISION == 6464
   typedef struct cdouble {double x; double y;} real;
   typedef struct cdouble2 {real x; real y;} real2;
diff --git a/src/routine.cc b/src/routine.cc
index e0cc9a90..5f9b1c89 100644
--- a/src/routine.cc
+++ b/src/routine.cc
@@ -397,6 +397,7 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev
 // =================================================================================================
 
 // Compiles the templated class
+template class Routine<half>;
 template class Routine<float>;
 template class Routine<double>;
 template class Routine<float2>;
diff --git a/src/utilities.cc b/src/utilities.cc
index 68a4f02a..5325c3e8 100644
--- a/src/utilities.cc
+++ b/src/utilities.cc
@@ -29,6 +29,7 @@ std::string ToString(T value) {
 }
 template std::string ToString<int>(int value);
 template std::string ToString<size_t>(size_t value);
+template std::string ToString<half>(half value);
 template std::string ToString<float>(float value);
 template std::string ToString<double>(double value);
 
@@ -105,6 +106,9 @@ template <typename T>
 T ConvertArgument(const char* value) {
   return static_cast<T>(std::stoi(value));
 }
+template <> half ConvertArgument(const char* value) {
+  return static_cast<half>(std::stod(value));
+}
 template <> float ConvertArgument(const char* value) {
   return static_cast<float>(std::stod(value));
 }
@@ -147,6 +151,7 @@ T GetArgument(const int argc, char *argv[], std::string &help,
 // Compiles the above function
 template int GetArgument<int>(const int, char **, std::string&, const std::string&, const int);
 template size_t GetArgument<size_t>(const int, char **, std::string&, const std::string&, const size_t);
+template half GetArgument<half>(const int, char **, std::string&, const std::string&, const half);
 template float GetArgument<float>(const int, char **, std::string&, const std::string&, const float);
 template double GetArgument<double>(const int, char **, std::string&, const std::string&, const double);
 template float2 GetArgument<float2>(const int, char **, std::string&, const std::string&, const float2);
@@ -227,6 +232,16 @@ void PopulateVector(std::vector<double2> &vector) {
   for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); }
 }
 
+// Specialized versions of the above for half-precision
+template <>
+void PopulateVector(std::vector<half> &vector) {
+  auto lower_limit = static_cast<float>(kTestDataLowerLimit);
+  auto upper_limit = static_cast<float>(kTestDataUpperLimit);
+  std::mt19937 mt(GetRandomSeed());
+  std::uniform_real_distribution<float> dist(lower_limit, upper_limit);
+  for (auto &element: vector) { element = static_cast<half>(dist(mt)); }
+}
+
 // =================================================================================================
 
 // Returns a scalar with a default value
@@ -234,6 +249,7 @@ template <typename T>
 T GetScalar() {
   return static_cast<T>(2.0);
 }
+template half GetScalar<half>();
 template float GetScalar<float>();
 template double GetScalar<double>();
 
@@ -288,6 +304,10 @@ template <> bool PrecisionSupported<double2>(const Device &device) {
   auto extensions = device.Capabilities();
   return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
 }
+template <> bool PrecisionSupported<half>(const Device &device) {
+  auto extensions = device.Capabilities();
+  return (extensions.find(kKhronosHalfPrecision) == std::string::npos) ? false : true;
+}
 
 // =================================================================================================
 } // namespace clblast
-- 
cgit v1.2.3