summaryrefslogtreecommitdiff
path: root/test/correctness
diff options
context:
space:
mode:
authorCNugteren <web@cedricnugteren.nl>2015-07-08 07:21:44 +0200
committerCNugteren <web@cedricnugteren.nl>2015-07-08 07:21:44 +0200
commit82469fc76432f37d955d87b05bf80b02026d19ff (patch)
tree30d12a1c47ddcdf9ed59ee5bbddda2f019c2140c /test/correctness
parent599f9a70a6bb2388c27b7981276e4d39497f90fb (diff)
The testers now distinguish between the memory and alpha/beta data-type
Diffstat (limited to 'test/correctness')
-rw-r--r--test/correctness/routines/xaxpy.cc8
-rw-r--r--test/correctness/routines/xgemm.cc8
-rw-r--r--test/correctness/routines/xgemv.cc8
-rw-r--r--test/correctness/routines/xherk.cc92
-rw-r--r--test/correctness/routines/xsymm.cc8
-rw-r--r--test/correctness/routines/xsyr2k.cc8
-rw-r--r--test/correctness/routines/xsyrk.cc8
-rw-r--r--test/correctness/routines/xtrmm.cc8
-rw-r--r--test/correctness/testblas.cc46
-rw-r--r--test/correctness/testblas.h41
-rw-r--r--test/correctness/tester.cc217
-rw-r--r--test/correctness/tester.h54
12 files changed, 312 insertions, 194 deletions
diff --git a/test/correctness/routines/xaxpy.cc b/test/correctness/routines/xaxpy.cc
index 89315a0d..cf23ca9f 100644
--- a/test/correctness/routines/xaxpy.cc
+++ b/test/correctness/routines/xaxpy.cc
@@ -22,10 +22,10 @@ template <typename T>
void RunTest(int argc, char *argv[], const bool silent, const std::string &name) {
// Creates a tester
- TestBlas<T> tester{argc, argv, silent, name, TestXaxpy<T>::GetOptions(),
- TestXaxpy<T>::RunRoutine, TestXaxpy<T>::RunReference,
- TestXaxpy<T>::DownloadResult, TestXaxpy<T>::GetResultIndex,
- TestXaxpy<T>::ResultID1, TestXaxpy<T>::ResultID2};
+ TestBlas<T,T> tester{argc, argv, silent, name, TestXaxpy<T>::GetOptions(),
+ TestXaxpy<T>::RunRoutine, TestXaxpy<T>::RunReference,
+ TestXaxpy<T>::DownloadResult, TestXaxpy<T>::GetResultIndex,
+ TestXaxpy<T>::ResultID1, TestXaxpy<T>::ResultID2};
// This variable holds the arguments relevant for this routine
auto args = Arguments<T>{};
diff --git a/test/correctness/routines/xgemm.cc b/test/correctness/routines/xgemm.cc
index 72843d45..8a50e1ca 100644
--- a/test/correctness/routines/xgemm.cc
+++ b/test/correctness/routines/xgemm.cc
@@ -22,10 +22,10 @@ template <typename T>
void RunTest(int argc, char *argv[], const bool silent, const std::string &name) {
// Creates a tester
- TestBlas<T> tester{argc, argv, silent, name, TestXgemm<T>::GetOptions(),
- TestXgemm<T>::RunRoutine, TestXgemm<T>::RunReference,
- TestXgemm<T>::DownloadResult, TestXgemm<T>::GetResultIndex,
- TestXgemm<T>::ResultID1, TestXgemm<T>::ResultID2};
+ TestBlas<T,T> tester{argc, argv, silent, name, TestXgemm<T>::GetOptions(),
+ TestXgemm<T>::RunRoutine, TestXgemm<T>::RunReference,
+ TestXgemm<T>::DownloadResult, TestXgemm<T>::GetResultIndex,
+ TestXgemm<T>::ResultID1, TestXgemm<T>::ResultID2};
// This variable holds the arguments relevant for this routine
auto args = Arguments<T>{};
diff --git a/test/correctness/routines/xgemv.cc b/test/correctness/routines/xgemv.cc
index f1100810..50ce4699 100644
--- a/test/correctness/routines/xgemv.cc
+++ b/test/correctness/routines/xgemv.cc
@@ -22,10 +22,10 @@ template <typename T>
void RunTest(int argc, char *argv[], const bool silent, const std::string &name) {
// Creates a tester
- TestBlas<T> tester{argc, argv, silent, name, TestXgemv<T>::GetOptions(),
- TestXgemv<T>::RunRoutine, TestXgemv<T>::RunReference,
- TestXgemv<T>::DownloadResult, TestXgemv<T>::GetResultIndex,
- TestXgemv<T>::ResultID1, TestXgemv<T>::ResultID2};
+ TestBlas<T,T> tester{argc, argv, silent, name, TestXgemv<T>::GetOptions(),
+ TestXgemv<T>::RunRoutine, TestXgemv<T>::RunReference,
+ TestXgemv<T>::DownloadResult, TestXgemv<T>::GetResultIndex,
+ TestXgemv<T>::ResultID1, TestXgemv<T>::ResultID2};
// This variable holds the arguments relevant for this routine
auto args = Arguments<T>{};
diff --git a/test/correctness/routines/xherk.cc b/test/correctness/routines/xherk.cc
new file mode 100644
index 00000000..dc5c6caf
--- /dev/null
+++ b/test/correctness/routines/xherk.cc
@@ -0,0 +1,92 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the tests for the Xherk routine.
+//
+// =================================================================================================
+
+#include "correctness/testblas.h"
+#include "routines/xherk.h"
+
+namespace clblast {
+// =================================================================================================
+
+// The correctness tester
+template <typename T, typename U>
+void RunTest(int argc, char *argv[], const bool silent, const std::string &name) {
+
+ // Creates a tester
+ TestBlas<T,U> tester{argc, argv, silent, name, TestXherk<T,U>::GetOptions(),
+ TestXherk<T,U>::RunRoutine, TestXherk<T,U>::RunReference,
+ TestXherk<T,U>::DownloadResult, TestXherk<T,U>::GetResultIndex,
+ TestXherk<T,U>::ResultID1, TestXherk<T,U>::ResultID2};
+
+ // This variable holds the arguments relevant for this routine
+ auto args = Arguments<U>{};
+
+ // Loops over the test-cases from a data-layout point of view
+ for (auto &layout: tester.kLayouts) { args.layout = layout;
+ for (auto &triangle: tester.kTriangles) { args.triangle = triangle;
+ for (auto &a_transpose: {Transpose::kNo, Transpose::kConjugate}) { // Regular transpose not a
+ args.a_transpose = a_transpose; // valid BLAS option
+
+ // Creates the arguments vector for the regular tests
+ auto regular_test_vector = std::vector<Arguments<U>>{};
+ for (auto &n: tester.kMatrixDims) { args.n = n;
+ for (auto &k: tester.kMatrixDims) { args.k = k;
+ for (auto &a_ld: tester.kMatrixDims) { args.a_ld = a_ld;
+ for (auto &a_offset: tester.kOffsets) { args.a_offset = a_offset;
+ for (auto &c_ld: tester.kMatrixDims) { args.c_ld = c_ld;
+ for (auto &c_offset: tester.kOffsets) { args.c_offset = c_offset;
+ for (auto &alpha: tester.kAlphaValues) { args.alpha = alpha;
+ for (auto &beta: tester.kBetaValues) { args.beta = beta;
+ args.a_size = TestXherk<T,U>::GetSizeA(args);
+ args.c_size = TestXherk<T,U>::GetSizeC(args);
+ if (args.a_size<1 || args.c_size<1) { continue; }
+ regular_test_vector.push_back(args);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Creates the arguments vector for the invalid-buffer tests
+ auto invalid_test_vector = std::vector<Arguments<U>>{};
+ args.n = args.k = tester.kBufferSize;
+ args.a_ld = args.c_ld = tester.kBufferSize;
+ args.a_offset = args.c_offset = 0;
+ for (auto &a_size: tester.kMatSizes) { args.a_size = a_size;
+ for (auto &c_size: tester.kMatSizes) { args.c_size = c_size;
+ invalid_test_vector.push_back(args);
+ }
+ }
+
+ // Runs the tests
+ const auto case_name = ToString(layout)+" "+ToString(triangle)+" "+ToString(a_transpose);
+ tester.TestRegular(regular_test_vector, case_name);
+ tester.TestInvalid(invalid_test_vector, case_name);
+ }
+ }
+ }
+}
+
+// =================================================================================================
+} // namespace clblast
+
+// Main function (not within the clblast namespace)
+int main(int argc, char *argv[]) {
+ clblast::RunTest<clblast::float2,float>(argc, argv, false, "CHERK");
+ clblast::RunTest<clblast::double2,double>(argc, argv, true, "ZHERK");
+ return 0;
+}
+
+// =================================================================================================
diff --git a/test/correctness/routines/xsymm.cc b/test/correctness/routines/xsymm.cc
index 3da654c3..a919a056 100644
--- a/test/correctness/routines/xsymm.cc
+++ b/test/correctness/routines/xsymm.cc
@@ -22,10 +22,10 @@ template <typename T>
void RunTest(int argc, char *argv[], const bool silent, const std::string &name) {
// Creates a tester
- TestBlas<T> tester{argc, argv, silent, name, TestXsymm<T>::GetOptions(),
- TestXsymm<T>::RunRoutine, TestXsymm<T>::RunReference,
- TestXsymm<T>::DownloadResult, TestXsymm<T>::GetResultIndex,
- TestXsymm<T>::ResultID1, TestXsymm<T>::ResultID2};
+ TestBlas<T,T> tester{argc, argv, silent, name, TestXsymm<T>::GetOptions(),
+ TestXsymm<T>::RunRoutine, TestXsymm<T>::RunReference,
+ TestXsymm<T>::DownloadResult, TestXsymm<T>::GetResultIndex,
+ TestXsymm<T>::ResultID1, TestXsymm<T>::ResultID2};
// This variable holds the arguments relevant for this routine
auto args = Arguments<T>{};
diff --git a/test/correctness/routines/xsyr2k.cc b/test/correctness/routines/xsyr2k.cc
index 8b03087c..736aa4e5 100644
--- a/test/correctness/routines/xsyr2k.cc
+++ b/test/correctness/routines/xsyr2k.cc
@@ -22,10 +22,10 @@ template <typename T>
void RunTest(int argc, char *argv[], const bool silent, const std::string &name) {
// Creates a tester
- TestBlas<T> tester{argc, argv, silent, name, TestXsyr2k<T>::GetOptions(),
- TestXsyr2k<T>::RunRoutine, TestXsyr2k<T>::RunReference,
- TestXsyr2k<T>::DownloadResult, TestXsyr2k<T>::GetResultIndex,
- TestXsyr2k<T>::ResultID1, TestXsyr2k<T>::ResultID2};
+ TestBlas<T,T> tester{argc, argv, silent, name, TestXsyr2k<T>::GetOptions(),
+ TestXsyr2k<T>::RunRoutine, TestXsyr2k<T>::RunReference,
+ TestXsyr2k<T>::DownloadResult, TestXsyr2k<T>::GetResultIndex,
+ TestXsyr2k<T>::ResultID1, TestXsyr2k<T>::ResultID2};
// This variable holds the arguments relevant for this routine
auto args = Arguments<T>{};
diff --git a/test/correctness/routines/xsyrk.cc b/test/correctness/routines/xsyrk.cc
index d4552a78..a62a0ebf 100644
--- a/test/correctness/routines/xsyrk.cc
+++ b/test/correctness/routines/xsyrk.cc
@@ -22,10 +22,10 @@ template <typename T>
void RunTest(int argc, char *argv[], const bool silent, const std::string &name) {
// Creates a tester
- TestBlas<T> tester{argc, argv, silent, name, TestXsyrk<T>::GetOptions(),
- TestXsyrk<T>::RunRoutine, TestXsyrk<T>::RunReference,
- TestXsyrk<T>::DownloadResult, TestXsyrk<T>::GetResultIndex,
- TestXsyrk<T>::ResultID1, TestXsyrk<T>::ResultID2};
+ TestBlas<T,T> tester{argc, argv, silent, name, TestXsyrk<T>::GetOptions(),
+ TestXsyrk<T>::RunRoutine, TestXsyrk<T>::RunReference,
+ TestXsyrk<T>::DownloadResult, TestXsyrk<T>::GetResultIndex,
+ TestXsyrk<T>::ResultID1, TestXsyrk<T>::ResultID2};
// This variable holds the arguments relevant for this routine
auto args = Arguments<T>{};
diff --git a/test/correctness/routines/xtrmm.cc b/test/correctness/routines/xtrmm.cc
index 943fb664..0bb6294c 100644
--- a/test/correctness/routines/xtrmm.cc
+++ b/test/correctness/routines/xtrmm.cc
@@ -22,10 +22,10 @@ template <typename T>
void RunTest(int argc, char *argv[], const bool silent, const std::string &name) {
// Creates a tester
- TestBlas<T> tester{argc, argv, silent, name, TestXtrmm<T>::GetOptions(),
- TestXtrmm<T>::RunRoutine, TestXtrmm<T>::RunReference,
- TestXtrmm<T>::DownloadResult, TestXtrmm<T>::GetResultIndex,
- TestXtrmm<T>::ResultID1, TestXtrmm<T>::ResultID2};
+ TestBlas<T,T> tester{argc, argv, silent, name, TestXtrmm<T>::GetOptions(),
+ TestXtrmm<T>::RunRoutine, TestXtrmm<T>::RunReference,
+ TestXtrmm<T>::DownloadResult, TestXtrmm<T>::GetResultIndex,
+ TestXtrmm<T>::ResultID1, TestXtrmm<T>::ResultID2};
// This variable holds the arguments relevant for this routine
auto args = Arguments<T>{};
diff --git a/test/correctness/testblas.cc b/test/correctness/testblas.cc
index 0e72e429..5951b177 100644
--- a/test/correctness/testblas.cc
+++ b/test/correctness/testblas.cc
@@ -19,21 +19,23 @@ namespace clblast {
// =================================================================================================
// The transpose-options to test with (data-type dependent)
-template <> const std::vector<Transpose> TestBlas<float>::kTransposes = {Transpose::kNo, Transpose::kYes};
-template <> const std::vector<Transpose> TestBlas<double>::kTransposes = {Transpose::kNo, Transpose::kYes};
-template <> const std::vector<Transpose> TestBlas<float2>::kTransposes = {Transpose::kNo, Transpose::kYes, Transpose::kConjugate};
-template <> const std::vector<Transpose> TestBlas<double2>::kTransposes = {Transpose::kNo, Transpose::kYes, Transpose::kConjugate};
+template <> const std::vector<Transpose> TestBlas<float,float>::kTransposes = {Transpose::kNo, Transpose::kYes};
+template <> const std::vector<Transpose> TestBlas<double,double>::kTransposes = {Transpose::kNo, Transpose::kYes};
+template <> const std::vector<Transpose> TestBlas<float2,float2>::kTransposes = {Transpose::kNo, Transpose::kYes, Transpose::kConjugate};
+template <> const std::vector<Transpose> TestBlas<double2,double2>::kTransposes = {Transpose::kNo, Transpose::kYes, Transpose::kConjugate};
+template <> const std::vector<Transpose> TestBlas<float2,float>::kTransposes = {Transpose::kNo, Transpose::kConjugate};
+template <> const std::vector<Transpose> TestBlas<double2,double>::kTransposes = {Transpose::kNo, Transpose::kConjugate};
// =================================================================================================
// Constructor, initializes the base class tester and input data
-template <typename T>
-TestBlas<T>::TestBlas(int argc, char *argv[], const bool silent,
- const std::string &name, const std::vector<std::string> &options,
- const Routine run_routine, const Routine run_reference,
- const ResultGet get_result, const ResultIndex get_index,
- const ResultIterator get_id1, const ResultIterator get_id2):
- Tester<T>{argc, argv, silent, name, options},
+template <typename T, typename U>
+TestBlas<T,U>::TestBlas(int argc, char *argv[], const bool silent,
+ const std::string &name, const std::vector<std::string> &options,
+ const Routine run_routine, const Routine run_reference,
+ const ResultGet get_result, const ResultIndex get_index,
+ const ResultIterator get_id1, const ResultIterator get_id2):
+ Tester<T,U>{argc, argv, silent, name, options},
run_routine_(run_routine),
run_reference_(run_reference),
get_result_(get_result),
@@ -65,9 +67,9 @@ TestBlas<T>::TestBlas(int argc, char *argv[], const bool silent,
// ===============================================================================================
// Tests the routine for a wide variety of parameters
-template <typename T>
-void TestBlas<T>::TestRegular(std::vector<Arguments<T>> &test_vector, const std::string &name) {
- if (!PrecisionSupported()) { return; }
+template <typename T, typename U>
+void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const std::string &name) {
+ if (!PrecisionSupported<T>(device_)) { return; }
TestStart("regular behaviour", name);
// Iterates over all the to-be-tested combinations of arguments
@@ -132,9 +134,9 @@ void TestBlas<T>::TestRegular(std::vector<Arguments<T>> &test_vector, const std:
// Tests the routine for cases with invalid OpenCL memory buffer sizes. Tests only on return-types,
// does not test for results (if any).
-template <typename T>
-void TestBlas<T>::TestInvalid(std::vector<Arguments<T>> &test_vector, const std::string &name) {
- if (!PrecisionSupported()) { return; }
+template <typename T, typename U>
+void TestBlas<T,U>::TestInvalid(std::vector<Arguments<U>> &test_vector, const std::string &name) {
+ if (!PrecisionSupported<T>(device_)) { return; }
TestStart("invalid buffer sizes", name);
// Iterates over all the to-be-tested combinations of arguments
@@ -176,10 +178,12 @@ void TestBlas<T>::TestInvalid(std::vector<Arguments<T>> &test_vector, const std:
// =================================================================================================
// Compiles the templated class
-template class TestBlas<float>;
-template class TestBlas<double>;
-template class TestBlas<float2>;
-template class TestBlas<double2>;
+template class TestBlas<float, float>;
+template class TestBlas<double, double>;
+template class TestBlas<float2, float2>;
+template class TestBlas<double2, double2>;
+template class TestBlas<float2, float>;
+template class TestBlas<double2, double>;
// =================================================================================================
} // namespace clblast
diff --git a/test/correctness/testblas.h b/test/correctness/testblas.h
index 7469700d..96c140c1 100644
--- a/test/correctness/testblas.h
+++ b/test/correctness/testblas.h
@@ -9,6 +9,8 @@
//
// This file tests any CLBlast routine. It contains two types of tests: one testing all sorts of
// input combinations, and one deliberatly testing with invalid values.
+// Typename T: the data-type of the routine's memory buffers (==precision)
+// Typename U: the data-type of the alpha and beta arguments
//
// =================================================================================================
@@ -24,23 +26,22 @@ namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
-template <typename T>
-class TestBlas: public Tester<T> {
+template <typename T, typename U>
+class TestBlas: public Tester<T,U> {
public:
// Uses several variables from the Tester class
- using Tester<T>::context_;
- using Tester<T>::queue_;
+ using Tester<T,U>::context_;
+ using Tester<T,U>::queue_;
+ using Tester<T,U>::full_test_;
+ using Tester<T,U>::device_;
// Uses several helper functions from the Tester class
- using Tester<T>::TestStart;
- using Tester<T>::TestEnd;
- using Tester<T>::TestSimilarity;
- using Tester<T>::TestErrorCount;
- using Tester<T>::TestErrorCodes;
- using Tester<T>::GetExampleScalars;
- using Tester<T>::GetOffsets;
- using Tester<T>::PrecisionSupported;
+ using Tester<T,U>::TestStart;
+ using Tester<T,U>::TestEnd;
+ using Tester<T,U>::TestErrorCount;
+ using Tester<T,U>::TestErrorCodes;
+ using Tester<T,U>::GetOffsets;
// Test settings for the regular test. Append to these lists in case more tests are required.
const std::vector<size_t> kVectorDims = { 7, 93, 4096 };
@@ -48,8 +49,8 @@ class TestBlas: public Tester<T> {
const std::vector<size_t> kMatrixDims = { 7, 64 };
const std::vector<size_t> kMatrixVectorDims = { 61, 512 };
const std::vector<size_t> kOffsets = GetOffsets();
- const std::vector<T> kAlphaValues = GetExampleScalars();
- const std::vector<T> kBetaValues = GetExampleScalars();
+ const std::vector<U> kAlphaValues = GetExampleScalars<U>(full_test_);
+ const std::vector<U> kBetaValues = GetExampleScalars<U>(full_test_);
// Test settings for the invalid tests
const std::vector<size_t> kInvalidIncrements = { 0, 1 };
@@ -65,10 +66,10 @@ class TestBlas: public Tester<T> {
static const std::vector<Transpose> kTransposes; // Data-type dependent, see .cc-file
// Shorthand for the routine-specific functions passed to the tester
- using Routine = std::function<StatusCode(const Arguments<T>&, const Buffers&, CommandQueue&)>;
- using ResultGet = std::function<std::vector<T>(const Arguments<T>&, Buffers&, CommandQueue&)>;
- using ResultIndex = std::function<size_t(const Arguments<T>&, const size_t, const size_t)>;
- using ResultIterator = std::function<size_t(const Arguments<T>&)>;
+ using Routine = std::function<StatusCode(const Arguments<U>&, const Buffers&, CommandQueue&)>;
+ using ResultGet = std::function<std::vector<T>(const Arguments<U>&, Buffers&, CommandQueue&)>;
+ using ResultIndex = std::function<size_t(const Arguments<U>&, const size_t, const size_t)>;
+ using ResultIterator = std::function<size_t(const Arguments<U>&)>;
// Constructor, initializes the base class tester and input data
TestBlas(int argc, char *argv[], const bool silent,
@@ -77,8 +78,8 @@ class TestBlas: public Tester<T> {
const ResultIndex get_index, const ResultIterator get_id1, const ResultIterator get_id2);
// The test functions, taking no inputs
- void TestRegular(std::vector<Arguments<T>> &test_vector, const std::string &name);
- void TestInvalid(std::vector<Arguments<T>> &test_vector, const std::string &name);
+ void TestRegular(std::vector<Arguments<U>> &test_vector, const std::string &name);
+ void TestInvalid(std::vector<Arguments<U>> &test_vector, const std::string &name);
private:
diff --git a/test/correctness/tester.cc b/test/correctness/tester.cc
index db4ee619..378968ed 100644
--- a/test/correctness/tester.cc
+++ b/test/correctness/tester.cc
@@ -23,9 +23,9 @@ namespace clblast {
// General constructor for all CLBlast testers. It prints out the test header to stdout and sets-up
// the clBLAS library for reference.
-template <typename T>
-Tester<T>::Tester(int argc, char *argv[], const bool silent,
- const std::string &name, const std::vector<std::string> &options):
+template <typename T, typename U>
+Tester<T,U>::Tester(int argc, char *argv[], const bool silent,
+ const std::string &name, const std::vector<std::string> &options):
help_("Options given/available:\n"),
platform_(Platform(GetArgument(argc, argv, help_, kArgPlatform, size_t{0}))),
device_(Device(platform_, kDeviceType, GetArgument(argc, argv, help_, kArgDevice, size_t{0}))),
@@ -51,7 +51,7 @@ Tester<T>::Tester(int argc, char *argv[], const bool silent,
kPrintMessage.c_str(), name.c_str(), kPrintEnd.c_str());
// Checks whether the precision is supported
- if (!PrecisionSupported()) {
+ if (!PrecisionSupported<T>(device_)) {
fprintf(stdout, "\n* All tests skipped: %sUnsupported precision%s\n",
kPrintWarning.c_str(), kPrintEnd.c_str());
return;
@@ -76,9 +76,9 @@ Tester<T>::Tester(int argc, char *argv[], const bool silent,
}
// Destructor prints the summary of the test cases and cleans-up the clBLAS library
-template <typename T>
-Tester<T>::~Tester() {
- if (PrecisionSupported()) {
+template <typename T, typename U>
+Tester<T,U>::~Tester() {
+ if (PrecisionSupported<T>(device_)) {
fprintf(stdout, "* Completed all test-cases for this routine. Results:\n");
fprintf(stdout, " %lu test(s) passed\n", tests_passed_);
if (tests_skipped_ > 0) { fprintf(stdout, "%s", kPrintWarning.c_str()); }
@@ -94,8 +94,8 @@ Tester<T>::~Tester() {
// Function called at the start of each test. This prints a header with information about the
// test and re-initializes all test data-structures.
-template <typename T>
-void Tester<T>::TestStart(const std::string &test_name, const std::string &test_configuration) {
+template <typename T, typename U>
+void Tester<T,U>::TestStart(const std::string &test_name, const std::string &test_configuration) {
// Prints the header
fprintf(stdout, "* Testing %s'%s'%s for %s'%s'%s:\n",
@@ -113,8 +113,8 @@ void Tester<T>::TestStart(const std::string &test_name, const std::string &test_
// Function called at the end of each test. This prints errors if any occured. It also prints a
// summary of the number of sub-tests passed/failed.
-template <typename T>
-void Tester<T>::TestEnd() {
+template <typename T, typename U>
+void Tester<T,U>::TestEnd() {
fprintf(stdout, "\n");
tests_passed_ += num_passed_;
tests_failed_ += num_skipped_;
@@ -172,45 +172,9 @@ void Tester<T>::TestEnd() {
// =================================================================================================
-// Compares two floating point values and returns whether they are within an acceptable error
-// margin. This replaces GTest's EXPECT_NEAR().
-template <typename T>
-bool Tester<T>::TestSimilarity(const T val1, const T val2) {
- const auto difference = std::fabs(val1 - val2);
-
- // Shortcut, handles infinities
- if (val1 == val2) {
- return true;
- }
- // The values are zero or very small: the relative error is less meaningful
- else if (val1 == 0 || val2 == 0 || difference < static_cast<T>(kErrorMarginAbsolute)) {
- return (difference < static_cast<T>(kErrorMarginAbsolute));
- }
- // Use relative error
- else {
- return (difference / (std::fabs(val1)+std::fabs(val2))) < static_cast<T>(kErrorMarginRelative);
- }
-}
-
-// Specialisations for complex data-types
-template <>
-bool Tester<float2>::TestSimilarity(const float2 val1, const float2 val2) {
- auto real = Tester<float>::TestSimilarity(val1.real(), val2.real());
- auto imag = Tester<float>::TestSimilarity(val1.imag(), val2.imag());
- return (real && imag);
-}
-template <>
-bool Tester<double2>::TestSimilarity(const double2 val1, const double2 val2) {
- auto real = Tester<double>::TestSimilarity(val1.real(), val2.real());
- auto imag = Tester<double>::TestSimilarity(val1.imag(), val2.imag());
- return (real && imag);
-}
-
-// =================================================================================================
-
// Handles a 'pass' or 'error' depending on whether there are any errors
-template <typename T>
-void Tester<T>::TestErrorCount(const size_t errors, const size_t size, const Arguments<T> &args) {
+template <typename T, typename U>
+void Tester<T,U>::TestErrorCount(const size_t errors, const size_t size, const Arguments<U> &args) {
// Finished successfully
if (errors == 0) {
@@ -228,9 +192,9 @@ void Tester<T>::TestErrorCount(const size_t errors, const size_t size, const Arg
// Compares two status codes for equality. The outcome can be a pass (they are the same), a warning
// (CLBlast reported a compilation error), or an error (they are different).
-template <typename T>
-void Tester<T>::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status,
- const Arguments<T> &args) {
+template <typename T, typename U>
+void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status,
+ const Arguments<U> &args) {
// Finished successfully
if (clblas_status == clblast_status) {
@@ -261,62 +225,26 @@ void Tester<T>::TestErrorCodes(const StatusCode clblas_status, const StatusCode
// =================================================================================================
-// Retrieves a list of example scalar values, used for the alpha and beta arguments for the various
-// routines. This function is specialised for the different data-types.
-template <>
-const std::vector<float> Tester<float>::GetExampleScalars() {
- if (full_test_) { return {0.0f, 1.0f, 3.14f}; }
- else { return {3.14f}; }
-}
-template <>
-const std::vector<double> Tester<double>::GetExampleScalars() {
- if (full_test_) { return {0.0, 1.0, 3.14}; }
- else { return {3.14}; }
-}
-template <>
-const std::vector<float2> Tester<float2>::GetExampleScalars() {
- if (full_test_) { return {{0.0f, 0.0f}, {1.0f, 1.3f}, {2.42f, 3.14f}}; }
- else { return {{2.42f, 3.14f}}; }
-}
-template <>
-const std::vector<double2> Tester<double2>::GetExampleScalars() {
- if (full_test_) { return {{0.0, 0.0}, {1.0, 1.3}, {2.42, 3.14}}; }
- else { return {{2.42, 3.14}}; }
-}
-
// Retrieves the offset values to test with
-template <typename T>
-const std::vector<size_t> Tester<T>::GetOffsets() {
+template <typename T, typename U>
+const std::vector<size_t> Tester<T,U>::GetOffsets() const {
if (full_test_) { return {0, 10}; }
else { return {0}; }
}
// =================================================================================================
-template <> bool Tester<float>::PrecisionSupported() const { return true; }
-template <> bool Tester<float2>::PrecisionSupported() const { return true; }
-template <> bool Tester<double>::PrecisionSupported() const {
- auto extensions = device_.Extensions();
- return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
-}
-template <> bool Tester<double2>::PrecisionSupported() const {
- auto extensions = device_.Extensions();
- return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
-}
-
-// =================================================================================================
-
// A test can either pass, be skipped, or fail
-template <typename T>
-void Tester<T>::ReportPass() {
+template <typename T, typename U>
+void Tester<T,U>::ReportPass() {
num_passed_++;
}
-template <typename T>
-void Tester<T>::ReportSkipped() {
+template <typename T, typename U>
+void Tester<T,U>::ReportSkipped() {
num_skipped_++;
}
-template <typename T>
-void Tester<T>::ReportError(const ErrorLogEntry &error_log_entry) {
+template <typename T, typename U>
+void Tester<T,U>::ReportError(const ErrorLogEntry &error_log_entry) {
error_log_.push_back(error_log_entry);
num_failed_++;
}
@@ -325,8 +253,8 @@ void Tester<T>::ReportError(const ErrorLogEntry &error_log_entry) {
// Prints the test-result symbol to screen. This function limits the maximum number of symbols per
// line by printing newlines once every so many calls.
-template <typename T>
-void Tester<T>::PrintTestResult(const std::string &message) {
+template <typename T, typename U>
+void Tester<T,U>::PrintTestResult(const std::string &message) {
if (print_count_ == kResultsPerLine) {
print_count_ = 0;
fprintf(stdout, "\n ");
@@ -337,12 +265,97 @@ void Tester<T>::PrintTestResult(const std::string &message) {
}
// =================================================================================================
+// Below are the non-member functions (separated because of otherwise required partial class
+// template specialization)
+// =================================================================================================
+
+// Compares two floating point values and returns whether they are within an acceptable error
+// margin. This replaces GTest's EXPECT_NEAR().
+template <typename T>
+bool TestSimilarity(const T val1, const T val2) {
+ const auto difference = std::fabs(val1 - val2);
+
+ // Set the allowed error margin for floating-point comparisons
+ constexpr auto kErrorMarginRelative = 1.0e-2;
+ constexpr auto kErrorMarginAbsolute = 1.0e-10;
+
+ // Shortcut, handles infinities
+ if (val1 == val2) {
+ return true;
+ }
+ // The values are zero or very small: the relative error is less meaningful
+ else if (val1 == 0 || val2 == 0 || difference < static_cast<T>(kErrorMarginAbsolute)) {
+ return (difference < static_cast<T>(kErrorMarginAbsolute));
+ }
+ // Use relative error
+ else {
+ const auto absolute_sum = std::fabs(val1) + std::fabs(val2);
+ return (difference / absolute_sum) < static_cast<T>(kErrorMarginRelative);
+ }
+}
+
+// Compiles the default case for non-complex data-types
+template bool TestSimilarity<float>(const float, const float);
+template bool TestSimilarity<double>(const double, const double);
+
+// Specialisations for complex data-types
+template <>
+bool TestSimilarity(const float2 val1, const float2 val2) {
+ auto real = TestSimilarity(val1.real(), val2.real());
+ auto imag = TestSimilarity(val1.imag(), val2.imag());
+ return (real && imag);
+}
+template <>
+bool TestSimilarity(const double2 val1, const double2 val2) {
+ auto real = TestSimilarity(val1.real(), val2.real());
+ auto imag = TestSimilarity(val1.imag(), val2.imag());
+ return (real && imag);
+}
+
+// =================================================================================================
+
+// Retrieves a list of example scalar values, used for the alpha and beta arguments for the various
+// routines. This function is specialised for the different data-types.
+template <> const std::vector<float> GetExampleScalars(const bool full_test) {
+ if (full_test) { return {0.0f, 1.0f, 3.14f}; }
+ else { return {3.14f}; }
+}
+template <> const std::vector<double> GetExampleScalars(const bool full_test) {
+ if (full_test) { return {0.0, 1.0, 3.14}; }
+ else { return {3.14}; }
+}
+template <> const std::vector<float2> GetExampleScalars(const bool full_test) {
+ if (full_test) { return {{0.0f, 0.0f}, {1.0f, 1.3f}, {2.42f, 3.14f}}; }
+ else { return {{2.42f, 3.14f}}; }
+}
+template <> const std::vector<double2> GetExampleScalars(const bool full_test) {
+ if (full_test) { return {{0.0, 0.0}, {1.0, 1.3}, {2.42, 3.14}}; }
+ else { return {{2.42, 3.14}}; }
+}
+
+// =================================================================================================
+
+// Returns false is this precision is not supported by the device
+template <> bool PrecisionSupported<float>(const Device &) { return true; }
+template <> bool PrecisionSupported<float2>(const Device &) { return true; }
+template <> bool PrecisionSupported<double>(const Device &device) {
+ auto extensions = device.Extensions();
+ return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
+}
+template <> bool PrecisionSupported<double2>(const Device &device) {
+ auto extensions = device.Extensions();
+ return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
+}
+
+// =================================================================================================
// Compiles the templated class
-template class Tester<float>;
-template class Tester<double>;
-template class Tester<float2>;
-template class Tester<double2>;
+template class Tester<float, float>;
+template class Tester<double, double>;
+template class Tester<float2, float2>;
+template class Tester<double2, double2>;
+template class Tester<float2, float>;
+template class Tester<double2, double>;
// =================================================================================================
} // namespace clblast
diff --git a/test/correctness/tester.h b/test/correctness/tester.h
index 9c4a9e86..93515138 100644
--- a/test/correctness/tester.h
+++ b/test/correctness/tester.h
@@ -10,6 +10,8 @@
// This file implements the Tester class, providing a test-framework. GTest was used before, but
// was not able to handle certain cases (e.g. template type + parameters). This is its (basic)
// custom replacement.
+// Typename T: the data-type of the routine's memory buffers (==precision)
+// Typename U: the data-type of the alpha and beta arguments
//
// =================================================================================================
@@ -30,7 +32,7 @@ namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
-template <typename T>
+template <typename T, typename U>
class Tester {
public:
@@ -43,10 +45,6 @@ class Tester {
// Error percentage is not applicable: error was caused by an incorrect status
static constexpr auto kStatusError = -1.0f;
- // Set the allowed error margin for floating-point comparisons
- static constexpr auto kErrorMarginRelative = 1.0e-2;
- static constexpr auto kErrorMarginAbsolute = 1.0e-10;
-
// Constants holding start and end strings for terminal-output in colour
const std::string kPrintError{"\x1b[31m"};
const std::string kPrintSuccess{"\x1b[32m"};
@@ -67,7 +65,7 @@ class Tester {
StatusCode status_expect;
StatusCode status_found;
float error_percentage;
- Arguments<T> args;
+ Arguments<U> args;
};
// Creates an instance of the tester, running on a particular OpenCL platform and device. It
@@ -80,25 +78,13 @@ class Tester {
void TestStart(const std::string &test_name, const std::string &test_configuration);
void TestEnd();
- // Compares two floating point values for similarity. Allows for a certain relative error margin.
- static bool TestSimilarity(const T val1, const T val2);
-
// Tests either an error count (should be zero) or two error codes (must match)
- void TestErrorCount(const size_t errors, const size_t size, const Arguments<T> &args);
+ void TestErrorCount(const size_t errors, const size_t size, const Arguments<U> &args);
void TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status,
- const Arguments<T> &args);
+ const Arguments<U> &args);
protected:
- // Retrieves a list of example scalars of the right type
- const std::vector<T> GetExampleScalars();
-
- // Retrieves a list of offset values to test
- const std::vector<size_t> GetOffsets();
-
- // Returns false is this precision is not supported by the device
- bool PrecisionSupported() const;
-
// The help-message
std::string help_;
@@ -108,6 +94,12 @@ class Tester {
Context context_;
CommandQueue queue_;
+ // Whether or not to run the full test-suite or just a smoke test
+ bool full_test_;
+
+ // Retrieves the offset values to test with
+ const std::vector<size_t> GetOffsets() const;
+
private:
// Internal methods to report a passed, skipped, or failed test
@@ -118,9 +110,6 @@ class Tester {
// Prints the error or success symbol to screen
void PrintTestResult(const std::string &message);
- // Whether or not to run the full test-suite or just a smoke test
- bool full_test_;
-
// Logging and counting occurrences of errors
std::vector<ErrorLogEntry> error_log_;
size_t num_passed_;
@@ -140,6 +129,25 @@ class Tester {
};
// =================================================================================================
+// Below are the non-member functions (separated because of otherwise required partial class
+// template specialization)
+// =================================================================================================
+
+// Compares two floating point values and returns whether they are within an acceptable error
+// margin. This replaces GTest's EXPECT_NEAR().
+template <typename T>
+bool TestSimilarity(const T val1, const T val2);
+
+// Retrieves a list of example scalar values, used for the alpha and beta arguments for the various
+// routines. This function is specialised for the different data-types.
+template <typename T>
+const std::vector<T> GetExampleScalars(const bool full_test);
+
+// Returns false is this precision is not supported by the device
+template <typename T>
+bool PrecisionSupported(const Device &device);
+
+// =================================================================================================
} // namespace clblast
// CLBLAST_TEST_CORRECTNESS_TESTER_H_