summaryrefslogtreecommitdiff
path: root/src/utilities.cc
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-05-12 19:56:21 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-05-12 19:56:21 +0200
commitf2ba75890c522b4fe1762bfeac3e08667cf9588a (patch)
tree82e22cb72fbfb135570ce3bf3234bd1f60c760c1 /src/utilities.cc
parent1c72d225c53c123ed810cf3f56f5c92603f7f791 (diff)
Initial changes in preparation for half-precision fp16 support
Diffstat (limited to 'src/utilities.cc')
-rw-r--r--src/utilities.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/src/utilities.cc b/src/utilities.cc
index 68a4f02a..5325c3e8 100644
--- a/src/utilities.cc
+++ b/src/utilities.cc
@@ -29,6 +29,7 @@ std::string ToString(T value) {
}
template std::string ToString<int>(int value);
template std::string ToString<size_t>(size_t value);
+template std::string ToString<half>(half value);
template std::string ToString<float>(float value);
template std::string ToString<double>(double value);
@@ -105,6 +106,9 @@ template <typename T>
T ConvertArgument(const char* value) {
return static_cast<T>(std::stoi(value));
}
+template <> half ConvertArgument(const char* value) {
+ return static_cast<half>(std::stod(value));
+}
template <> float ConvertArgument(const char* value) {
return static_cast<float>(std::stod(value));
}
@@ -147,6 +151,7 @@ T GetArgument(const int argc, char *argv[], std::string &help,
// Compiles the above function
template int GetArgument<int>(const int, char **, std::string&, const std::string&, const int);
template size_t GetArgument<size_t>(const int, char **, std::string&, const std::string&, const size_t);
+template half GetArgument<half>(const int, char **, std::string&, const std::string&, const half);
template float GetArgument<float>(const int, char **, std::string&, const std::string&, const float);
template double GetArgument<double>(const int, char **, std::string&, const std::string&, const double);
template float2 GetArgument<float2>(const int, char **, std::string&, const std::string&, const float2);
@@ -227,6 +232,16 @@ void PopulateVector(std::vector<double2> &vector) {
for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); }
}
+// Specialized versions of the above for half-precision
+template <>
+void PopulateVector(std::vector<half> &vector) {
+ auto lower_limit = static_cast<float>(kTestDataLowerLimit);
+ auto upper_limit = static_cast<float>(kTestDataUpperLimit);
+ std::mt19937 mt(GetRandomSeed());
+ std::uniform_real_distribution<float> dist(lower_limit, upper_limit);
+ for (auto &element: vector) { element = static_cast<half>(dist(mt)); }
+}
+
// =================================================================================================
// Returns a scalar with a default value
@@ -234,6 +249,7 @@ template <typename T>
T GetScalar() {
return static_cast<T>(2.0);
}
+template half GetScalar<half>();
template float GetScalar<float>();
template double GetScalar<double>();
@@ -288,6 +304,10 @@ template <> bool PrecisionSupported<double2>(const Device &device) {
auto extensions = device.Capabilities();
return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
}
+template <> bool PrecisionSupported<half>(const Device &device) {
+ auto extensions = device.Capabilities();
+ return (extensions.find(kKhronosHalfPrecision) == std::string::npos) ? false : true;
+}
// =================================================================================================
} // namespace clblast