summaryrefslogtreecommitdiff
path: root/src/utilities
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-11-13 21:10:44 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-11-13 21:10:44 +0100
commit4bac1287f2d49bece72822bf6032e4da56a2dd2d (patch)
tree77fd5e9fad27179feabc80f97f9f8abf9bbc99d1 /src/utilities
parent677afd3b96b2cbd3d2aae77e90cab87d2cc1eaa2 (diff)
Moved square-difference utility function for use in the tuners
Diffstat (limited to 'src/utilities')
-rw-r--r--src/utilities/utilities.cpp31
-rw-r--r--src/utilities/utilities.hpp6
2 files changed, 37 insertions, 0 deletions
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp
index f2574104..1546fbf5 100644
--- a/src/utilities/utilities.cpp
+++ b/src/utilities/utilities.cpp
@@ -397,6 +397,37 @@ template <> bool PrecisionSupported<half>(const Device &device) { return device.
// =================================================================================================
+// Retrieves the squared difference, used for example for computing the L2 error
+template <typename T>
+double SquaredDifference(const T val1, const T val2) {
+ const auto difference = (val1 - val2);
+ return static_cast<double>(difference * difference);
+}
+
+// Compiles the default case for standard data-types
+template double SquaredDifference<float>(const float, const float);
+template double SquaredDifference<double>(const double, const double);
+
+// Specialisations for non-standard data-types
+template <>
+double SquaredDifference(const float2 val1, const float2 val2) {
+ const auto real = SquaredDifference(val1.real(), val2.real());
+ const auto imag = SquaredDifference(val1.imag(), val2.imag());
+ return real + imag;
+}
+template <>
+double SquaredDifference(const double2 val1, const double2 val2) {
+ const auto real = SquaredDifference(val1.real(), val2.real());
+ const auto imag = SquaredDifference(val1.imag(), val2.imag());
+ return real + imag;
+}
+template <>
+double SquaredDifference(const half val1, const half val2) {
+ return SquaredDifference(HalfToFloat(val1), HalfToFloat(val2));
+}
+
+// =================================================================================================
+
// High-level info
std::string GetDeviceType(const Device& device) {
return device.Type();
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index f56226be..3f90906d 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -323,6 +323,12 @@ bool PrecisionSupported(const Device &device);
// =================================================================================================
+// Retrieves the squared difference, used for example for computing the L2 error
+template <typename T>
+double SquaredDifference(const T val1, const T val2);
+
+// =================================================================================================
+
// Device information in a specific CLBlast form
std::string GetDeviceType(const Device& device);
std::string GetDeviceVendor(const Device& device);