diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-12 19:56:21 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-12 19:56:21 +0200 |
commit | f2ba75890c522b4fe1762bfeac3e08667cf9588a (patch) | |
tree | 82e22cb72fbfb135570ce3bf3234bd1f60c760c1 /src/utilities.cc | |
parent | 1c72d225c53c123ed810cf3f56f5c92603f7f791 (diff) |
Initial changes in preparation for half-precision fp16 support
Diffstat (limited to 'src/utilities.cc')
-rw-r--r-- | src/utilities.cc | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/src/utilities.cc b/src/utilities.cc index 68a4f02a..5325c3e8 100644 --- a/src/utilities.cc +++ b/src/utilities.cc @@ -29,6 +29,7 @@ std::string ToString(T value) { } template std::string ToString<int>(int value); template std::string ToString<size_t>(size_t value); +template std::string ToString<half>(half value); template std::string ToString<float>(float value); template std::string ToString<double>(double value); @@ -105,6 +106,9 @@ template <typename T> T ConvertArgument(const char* value) { return static_cast<T>(std::stoi(value)); } +template <> half ConvertArgument(const char* value) { + return static_cast<half>(std::stod(value)); +} template <> float ConvertArgument(const char* value) { return static_cast<float>(std::stod(value)); } @@ -147,6 +151,7 @@ T GetArgument(const int argc, char *argv[], std::string &help, // Compiles the above function template int GetArgument<int>(const int, char **, std::string&, const std::string&, const int); template size_t GetArgument<size_t>(const int, char **, std::string&, const std::string&, const size_t); +template half GetArgument<half>(const int, char **, std::string&, const std::string&, const half); template float GetArgument<float>(const int, char **, std::string&, const std::string&, const float); template double GetArgument<double>(const int, char **, std::string&, const std::string&, const double); template float2 GetArgument<float2>(const int, char **, std::string&, const std::string&, const float2); @@ -227,6 +232,16 @@ void PopulateVector(std::vector<double2> &vector) { for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); } } +// Specialized versions of the above for half-precision +template <> +void PopulateVector(std::vector<half> &vector) { + auto lower_limit = static_cast<float>(kTestDataLowerLimit); + auto upper_limit = static_cast<float>(kTestDataUpperLimit); + std::mt19937 mt(GetRandomSeed()); + std::uniform_real_distribution<float> dist(lower_limit, upper_limit); + for (auto &element: vector) { element = static_cast<half>(dist(mt)); } +} + // ================================================================================================= // Returns a scalar with a default value @@ -234,6 +249,7 @@ template <typename T> T GetScalar() { return static_cast<T>(2.0); } +template half GetScalar<half>(); template float GetScalar<float>(); template double GetScalar<double>(); @@ -288,6 +304,10 @@ template <> bool PrecisionSupported<double2>(const Device &device) { auto extensions = device.Capabilities(); return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; } +template <> bool PrecisionSupported<half>(const Device &device) { + auto extensions = device.Capabilities(); + return (extensions.find(kKhronosHalfPrecision) == std::string::npos) ? false : true; +} // ================================================================================================= } // namespace clblast |