summaryrefslogtreecommitdiff
path: root/src/utilities/utilities.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/utilities/utilities.cpp')
-rw-r--r--src/utilities/utilities.cpp118
1 files changed, 39 insertions, 79 deletions
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp
index b2ed2f0c..9cf75490 100644
--- a/src/utilities/utilities.cpp
+++ b/src/utilities/utilities.cpp
@@ -24,100 +24,52 @@ namespace clblast {
// =================================================================================================
// Returns a scalar with a default value
-template <typename T>
-T GetScalar() {
- return static_cast<T>(2.0);
-}
+template <typename T> T GetScalar() { return static_cast<T>(2.0); }
template float GetScalar<float>();
template double GetScalar<double>();
-
-// Specialized version of the above for half-precision
-template <>
-half GetScalar() {
- return FloatToHalf(2.0f);
-}
-
-// Specialized versions of the above for complex data-types
-template <>
-float2 GetScalar() {
- return {2.0f, 0.5f};
-}
-template <>
-double2 GetScalar() {
- return {2.0, 0.5};
-}
+template <> half GetScalar() { return FloatToHalf(2.0f); }
+template <> float2 GetScalar() { return {2.0f, 0.5f}; }
+template <> double2 GetScalar() { return {2.0, 0.5}; }
// Returns a scalar of value 0
-template <typename T>
-T ConstantZero() {
- return static_cast<T>(0.0);
-}
+template <typename T> T ConstantZero() { return static_cast<T>(0.0); }
template float ConstantZero<float>();
template double ConstantZero<double>();
-
-// Specialized version of the above for half-precision
-template <>
-half ConstantZero() {
- return FloatToHalf(0.0f);
-}
-
-// Specialized versions of the above for complex data-types
-template <>
-float2 ConstantZero() {
- return {0.0f, 0.0f};
-}
-template <>
-double2 ConstantZero() {
- return {0.0, 0.0};
-}
+template <> half ConstantZero() { return FloatToHalf(0.0f); }
+template <> float2 ConstantZero() { return {0.0f, 0.0f}; }
+template <> double2 ConstantZero() { return {0.0, 0.0}; }
// Returns a scalar of value 1
-template <typename T>
-T ConstantOne() {
- return static_cast<T>(1.0);
-}
+template <typename T> T ConstantOne() { return static_cast<T>(1.0); }
template float ConstantOne<float>();
template double ConstantOne<double>();
-
-// Specialized version of the above for half-precision
-template <>
-half ConstantOne() {
- return FloatToHalf(1.0f);
-}
-
-// Specialized versions of the above for complex data-types
-template <>
-float2 ConstantOne() {
- return {1.0f, 0.0f};
-}
-template <>
-double2 ConstantOne() {
- return {1.0, 0.0};
-}
+template <> half ConstantOne() { return FloatToHalf(1.0f); }
+template <> float2 ConstantOne() { return {1.0f, 0.0f}; }
+template <> double2 ConstantOne() { return {1.0, 0.0}; }
// Returns a scalar of value -1
-template <typename T>
-T ConstantNegOne() {
- return static_cast<T>(-1.0);
-}
+template <typename T> T ConstantNegOne() { return static_cast<T>(-1.0); }
template float ConstantNegOne<float>();
template double ConstantNegOne<double>();
+template <> half ConstantNegOne() { return FloatToHalf(-1.0f); }
+template <> float2 ConstantNegOne() { return {-1.0f, 0.0f}; }
+template <> double2 ConstantNegOne() { return {-1.0, 0.0}; }
-// Specialized version of the above for half-precision
-template <>
-half ConstantNegOne() {
- return FloatToHalf(-1.0f);
-}
-
-// Specialized versions of the above for complex data-types
-template <>
-float2 ConstantNegOne() {
- return {-1.0f, 0.0f};
-}
-template <>
-double2 ConstantNegOne() {
- return {-1.0, 0.0};
-}
+// Returns a scalar of value 1
+template <typename T> T ConstantTwo() { return static_cast<T>(2.0); }
+template float ConstantTwo<float>();
+template double ConstantTwo<double>();
+template <> half ConstantTwo() { return FloatToHalf(2.0f); }
+template <> float2 ConstantTwo() { return {2.0f, 0.0f}; }
+template <> double2 ConstantTwo() { return {2.0, 0.0}; }
+
+// Returns a small scalar value just larger than 0
+template <typename T> T SmallConstant() { return static_cast<T>(1e7); }
+template float SmallConstant<float>();
+template double SmallConstant<double>();
+template <> half SmallConstant() { return FloatToHalf(1e7); }
+template <> float2 SmallConstant() { return {1e7, 0.0f}; }
+template <> double2 SmallConstant() { return {1e7, 0.0}; }
// Returns the absolute value of a scalar
template <typename T> T AbsoluteValue(const T value) { return std::fabs(value); }
@@ -127,6 +79,14 @@ template <> half AbsoluteValue(const half value) { return FloatToHalf(std::fabs(
template <> float2 AbsoluteValue(const float2 value) { return std::abs(value); }
template <> double2 AbsoluteValue(const double2 value) { return std::abs(value); }
+// Returns whether a scalar is close to zero
+template <typename T> bool IsCloseToZero(const T value) { return (value > -SmallConstant<T>()) && (value < SmallConstant<T>()); }
+template bool IsCloseToZero<float>(const float);
+template bool IsCloseToZero<double>(const double);
+template <> bool IsCloseToZero(const half value) { return IsCloseToZero(HalfToFloat(value)); }
+template <> bool IsCloseToZero(const float2 value) { return IsCloseToZero(value.real()) || IsCloseToZero(value.imag()); }
+template <> bool IsCloseToZero(const double2 value) { return IsCloseToZero(value.real()) || IsCloseToZero(value.imag()); }
+
// =================================================================================================
// Implements the string conversion using std::to_string if possible