summaryrefslogtreecommitdiff
path: root/test/correctness/tester.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/correctness/tester.cpp')
-rw-r--r--test/correctness/tester.cpp64
1 files changed, 61 insertions, 3 deletions
diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp
index eb79008d..6cafd7bc 100644
--- a/test/correctness/tester.cpp
+++ b/test/correctness/tester.cpp
@@ -22,22 +22,46 @@
namespace clblast {
// =================================================================================================
-// Eror margings (relative and absolute)
+// Relative error margins
template <typename T>
float getRelativeErrorMargin() {
return 0.005f; // 0.5% is considered acceptable for float/double-precision
}
+template float getRelativeErrorMargin<float>(); // as the above default
+template float getRelativeErrorMargin<double>(); // as the above default
+template float getRelativeErrorMargin<float2>(); // as the above default
+template float getRelativeErrorMargin<double2>(); // as the above default
template <>
float getRelativeErrorMargin<half>() {
return 0.080f; // 8% (!) error is considered acceptable for half-precision
}
+
+// Absolute error margins
template <typename T>
float getAbsoluteErrorMargin() {
return 0.001f;
}
+template float getAbsoluteErrorMargin<float>(); // as the above default
+template float getAbsoluteErrorMargin<double>(); // as the above default
+template float getAbsoluteErrorMargin<float2>(); // as the above default
+template float getAbsoluteErrorMargin<double2>(); // as the above default
template <>
float getAbsoluteErrorMargin<half>() {
- return 0.10f; // especially small values are inaccurate for half-precision
+ return 0.15f; // especially small values are inaccurate for half-precision
+}
+
+// L2 error margins
+template <typename T>
+double getL2ErrorMargin() {
+ return 0.0f; // zero means don't look at the L2 error margin at all, use the other metrics
+}
+template double getL2ErrorMargin<float>(); // as the above default
+template double getL2ErrorMargin<double>(); // as the above default
+template double getL2ErrorMargin<float2>(); // as the above default
+template double getL2ErrorMargin<double2>(); // as the above default
+template <>
+double getL2ErrorMargin<half>() {
+ return 0.05; // half-precision results are considered OK as long as the L2 error is low enough
}
// Error margin: numbers beyond this value are considered equal to inf or NaN
@@ -144,6 +168,9 @@ Tester<T,U>::Tester(const std::vector<std::string> &arguments, const bool silent
kUnsupportedReference.c_str());
fprintf(stdout, "* Testing with error margins of %.1lf%% (relative) and %.3lf (absolute)\n",
100.0f * getRelativeErrorMargin<T>(), getAbsoluteErrorMargin<T>());
+ if (getL2ErrorMargin<T>() != 0.0f) {
+ fprintf(stdout, "* and a combined maximum allowed L2 error of %.2e\n", getL2ErrorMargin<T>());
+ }
// Initializes clBLAS
#ifdef CLBLAST_REF_CLBLAS
@@ -405,7 +432,7 @@ template <typename T, typename U>
void Tester<T,U>::PrintErrorLog(const std::vector<ErrorLogEntry> &error_log) {
for (auto &entry: error_log) {
if (entry.error_percentage != kStatusError) {
- fprintf(stdout, " Error rate %.1lf%%: ", entry.error_percentage);
+ fprintf(stdout, " Error rate %.2lf%%: ", entry.error_percentage);
}
else {
fprintf(stdout, " Status code %d (expected %d): ",
@@ -499,6 +526,37 @@ bool TestSimilarity(const half val1, const half val2) {
// =================================================================================================
+// Retrieves the squared difference, used for example for computing the L2 error
+template <typename T>
+double SquaredDifference(const T val1, const T val2) {
+ const auto difference = (val1 - val2);
+ return static_cast<double>(difference * difference);
+}
+
+// Compiles the default case for standard data-types
+template double SquaredDifference<float>(const float, const float);
+template double SquaredDifference<double>(const double, const double);
+
+// Specialisations for non-standard data-types
+template <>
+double SquaredDifference(const float2 val1, const float2 val2) {
+ const auto real = SquaredDifference(val1.real(), val2.real());
+ const auto imag = SquaredDifference(val1.imag(), val2.imag());
+ return real + imag;
+}
+template <>
+double SquaredDifference(const double2 val1, const double2 val2) {
+ const auto real = SquaredDifference(val1.real(), val2.real());
+ const auto imag = SquaredDifference(val1.imag(), val2.imag());
+ return real + imag;
+}
+template <>
+double SquaredDifference(const half val1, const half val2) {
+ return SquaredDifference(HalfToFloat(val1), HalfToFloat(val2));
+}
+
+// =================================================================================================
+
// Retrieves a list of example scalar values, used for the alpha and beta arguments for the various
// routines. This function is specialised for the different data-types.
template <> const std::vector<float> GetExampleScalars(const bool full_test) {