diff options
Diffstat (limited to 'src/utilities/utilities.cpp')
-rw-r--r-- | src/utilities/utilities.cpp | 118 |
1 files changed, 39 insertions, 79 deletions
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index b2ed2f0c..9cf75490 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -24,100 +24,52 @@ namespace clblast { // ================================================================================================= // Returns a scalar with a default value -template <typename T> -T GetScalar() { - return static_cast<T>(2.0); -} +template <typename T> T GetScalar() { return static_cast<T>(2.0); } template float GetScalar<float>(); template double GetScalar<double>(); - -// Specialized version of the above for half-precision -template <> -half GetScalar() { - return FloatToHalf(2.0f); -} - -// Specialized versions of the above for complex data-types -template <> -float2 GetScalar() { - return {2.0f, 0.5f}; -} -template <> -double2 GetScalar() { - return {2.0, 0.5}; -} +template <> half GetScalar() { return FloatToHalf(2.0f); } +template <> float2 GetScalar() { return {2.0f, 0.5f}; } +template <> double2 GetScalar() { return {2.0, 0.5}; } // Returns a scalar of value 0 -template <typename T> -T ConstantZero() { - return static_cast<T>(0.0); -} +template <typename T> T ConstantZero() { return static_cast<T>(0.0); } template float ConstantZero<float>(); template double ConstantZero<double>(); - -// Specialized version of the above for half-precision -template <> -half ConstantZero() { - return FloatToHalf(0.0f); -} - -// Specialized versions of the above for complex data-types -template <> -float2 ConstantZero() { - return {0.0f, 0.0f}; -} -template <> -double2 ConstantZero() { - return {0.0, 0.0}; -} +template <> half ConstantZero() { return FloatToHalf(0.0f); } +template <> float2 ConstantZero() { return {0.0f, 0.0f}; } +template <> double2 ConstantZero() { return {0.0, 0.0}; } // Returns a scalar of value 1 -template <typename T> -T ConstantOne() { - return static_cast<T>(1.0); -} +template <typename T> T ConstantOne() { return static_cast<T>(1.0); } template float ConstantOne<float>(); template double ConstantOne<double>(); - -// Specialized version of the above for half-precision -template <> -half ConstantOne() { - return FloatToHalf(1.0f); -} - -// Specialized versions of the above for complex data-types -template <> -float2 ConstantOne() { - return {1.0f, 0.0f}; -} -template <> -double2 ConstantOne() { - return {1.0, 0.0}; -} +template <> half ConstantOne() { return FloatToHalf(1.0f); } +template <> float2 ConstantOne() { return {1.0f, 0.0f}; } +template <> double2 ConstantOne() { return {1.0, 0.0}; } // Returns a scalar of value -1 -template <typename T> -T ConstantNegOne() { - return static_cast<T>(-1.0); -} +template <typename T> T ConstantNegOne() { return static_cast<T>(-1.0); } template float ConstantNegOne<float>(); template double ConstantNegOne<double>(); +template <> half ConstantNegOne() { return FloatToHalf(-1.0f); } +template <> float2 ConstantNegOne() { return {-1.0f, 0.0f}; } +template <> double2 ConstantNegOne() { return {-1.0, 0.0}; } -// Specialized version of the above for half-precision -template <> -half ConstantNegOne() { - return FloatToHalf(-1.0f); -} - -// Specialized versions of the above for complex data-types -template <> -float2 ConstantNegOne() { - return {-1.0f, 0.0f}; -} -template <> -double2 ConstantNegOne() { - return {-1.0, 0.0}; -} +// Returns a scalar of value 1 +template <typename T> T ConstantTwo() { return static_cast<T>(2.0); } +template float ConstantTwo<float>(); +template double ConstantTwo<double>(); +template <> half ConstantTwo() { return FloatToHalf(2.0f); } +template <> float2 ConstantTwo() { return {2.0f, 0.0f}; } +template <> double2 ConstantTwo() { return {2.0, 0.0}; } + +// Returns a small scalar value just larger than 0 +template <typename T> T SmallConstant() { return static_cast<T>(1e7); } +template float SmallConstant<float>(); +template double SmallConstant<double>(); +template <> half SmallConstant() { return FloatToHalf(1e7); } +template <> float2 SmallConstant() { return {1e7, 0.0f}; } +template <> double2 SmallConstant() { return {1e7, 0.0}; } // Returns the absolute value of a scalar template <typename T> T AbsoluteValue(const T value) { return std::fabs(value); } @@ -127,6 +79,14 @@ template <> half AbsoluteValue(const half value) { return FloatToHalf(std::fabs( template <> float2 AbsoluteValue(const float2 value) { return std::abs(value); } template <> double2 AbsoluteValue(const double2 value) { return std::abs(value); } +// Returns whether a scalar is close to zero +template <typename T> bool IsCloseToZero(const T value) { return (value > -SmallConstant<T>()) && (value < SmallConstant<T>()); } +template bool IsCloseToZero<float>(const float); +template bool IsCloseToZero<double>(const double); +template <> bool IsCloseToZero(const half value) { return IsCloseToZero(HalfToFloat(value)); } +template <> bool IsCloseToZero(const float2 value) { return IsCloseToZero(value.real()) || IsCloseToZero(value.imag()); } +template <> bool IsCloseToZero(const double2 value) { return IsCloseToZero(value.real()) || IsCloseToZero(value.imag()); } + // ================================================================================================= // Implements the string conversion using std::to_string if possible |