diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-11-13 21:10:44 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-11-13 21:10:44 +0100 |
commit | 4bac1287f2d49bece72822bf6032e4da56a2dd2d (patch) | |
tree | 77fd5e9fad27179feabc80f97f9f8abf9bbc99d1 /src/utilities | |
parent | 677afd3b96b2cbd3d2aae77e90cab87d2cc1eaa2 (diff) |
Moved square-difference utility function for use in the tuners
Diffstat (limited to 'src/utilities')
-rw-r--r-- | src/utilities/utilities.cpp | 31 | ||||
-rw-r--r-- | src/utilities/utilities.hpp | 6 |
2 files changed, 37 insertions, 0 deletions
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index f2574104..1546fbf5 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -397,6 +397,37 @@ template <> bool PrecisionSupported<half>(const Device &device) { return device. // ================================================================================================= +// Retrieves the squared difference, used for example for computing the L2 error +template <typename T> +double SquaredDifference(const T val1, const T val2) { + const auto difference = (val1 - val2); + return static_cast<double>(difference * difference); +} + +// Compiles the default case for standard data-types +template double SquaredDifference<float>(const float, const float); +template double SquaredDifference<double>(const double, const double); + +// Specialisations for non-standard data-types +template <> +double SquaredDifference(const float2 val1, const float2 val2) { + const auto real = SquaredDifference(val1.real(), val2.real()); + const auto imag = SquaredDifference(val1.imag(), val2.imag()); + return real + imag; +} +template <> +double SquaredDifference(const double2 val1, const double2 val2) { + const auto real = SquaredDifference(val1.real(), val2.real()); + const auto imag = SquaredDifference(val1.imag(), val2.imag()); + return real + imag; +} +template <> +double SquaredDifference(const half val1, const half val2) { + return SquaredDifference(HalfToFloat(val1), HalfToFloat(val2)); +} + +// ================================================================================================= + // High-level info std::string GetDeviceType(const Device& device) { return device.Type(); diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp index f56226be..3f90906d 100644 --- a/src/utilities/utilities.hpp +++ b/src/utilities/utilities.hpp @@ -323,6 +323,12 @@ bool PrecisionSupported(const Device &device); // ================================================================================================= +// Retrieves the squared difference, used for example for computing the L2 error +template <typename T> +double SquaredDifference(const T val1, const T val2); + +// ================================================================================================= + // Device information in a specific CLBlast form std::string GetDeviceType(const Device& device); std::string GetDeviceVendor(const Device& device); |