diff options
Diffstat (limited to 'test/correctness/testblas.h')
-rw-r--r-- | test/correctness/testblas.h | 47 |
1 files changed, 30 insertions, 17 deletions
diff --git a/test/correctness/testblas.h b/test/correctness/testblas.h index 7c9032bd..8181aaf6 100644 --- a/test/correctness/testblas.h +++ b/test/correctness/testblas.h @@ -68,7 +68,7 @@ class TestBlas: public Tester<T,U> { static const std::vector<Transpose> kTransposes; // Data-type dependent, see .cc-file // Shorthand for the routine-specific functions passed to the tester - using Routine = std::function<StatusCode(const Arguments<U>&, const Buffers<T>&, Queue&)>; + using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>; using ResultGet = std::function<std::vector<T>(const Arguments<U>&, Buffers<T>&, Queue&)>; using ResultIndex = std::function<size_t(const Arguments<U>&, const size_t, const size_t)>; using ResultIterator = std::function<size_t(const Arguments<U>&)>; @@ -76,8 +76,9 @@ class TestBlas: public Tester<T,U> { // Constructor, initializes the base class tester and input data TestBlas(int argc, char *argv[], const bool silent, const std::string &name, const std::vector<std::string> &options, - const Routine run_routine, const Routine run_reference, const ResultGet get_result, - const ResultIndex get_index, const ResultIterator get_id1, const ResultIterator get_id2); + const Routine run_routine, const Routine run_reference, + const ResultGet get_result, const ResultIndex get_index, + const ResultIterator get_id1, const ResultIterator get_id2); // The test functions, taking no inputs void TestRegular(std::vector<Arguments<U>> &test_vector, const std::string &name); @@ -110,9 +111,17 @@ class TestBlas: public Tester<T,U> { template <typename C, typename T, typename U> void RunTests(int argc, char *argv[], const bool silent, const std::string &name) { + // Sets the reference to test against + #ifdef CLBLAST_REF_CLBLAS + const auto reference_routine = C::RunReference1; // clBLAS when available + #else + const auto reference_routine = C::RunReference2; // otherwise CBLAS + #endif + // Creates a tester auto options = C::GetOptions(); - TestBlas<T,U> tester{argc, argv, silent, name, options, C::RunRoutine, C::RunReference, + TestBlas<T,U> tester{argc, argv, silent, name, options, + C::RunRoutine, reference_routine, C::DownloadResult, C::GetResultIndex, C::ResultID1, C::ResultID2}; // This variable holds the arguments relevant for this routine @@ -250,23 +259,25 @@ void RunTests(int argc, char *argv[], const bool silent, const std::string &name } // Creates the arguments vector for the invalid-buffer tests - auto invalid_test_vector = std::vector<Arguments<U>>{}; - auto i_args = args; - i_args.m = i_args.n = i_args.k = i_args.kl = i_args.ku = tester.kBufferSize; - i_args.a_ld = i_args.b_ld = i_args.c_ld = tester.kBufferSize; - for (auto &x_size: x_sizes) { i_args.x_size = x_size; - for (auto &y_size: y_sizes) { i_args.y_size = y_size; - for (auto &a_size: a_sizes) { i_args.a_size = a_size; - for (auto &b_size: b_sizes) { i_args.b_size = b_size; - for (auto &c_size: c_sizes) { i_args.c_size = c_size; - for (auto &ap_size: ap_sizes) { i_args.ap_size = ap_size; - invalid_test_vector.push_back(i_args); + #ifdef CLBLAST_REF_CLBLAS + auto invalid_test_vector = std::vector<Arguments<U>>{}; + auto i_args = args; + i_args.m = i_args.n = i_args.k = i_args.kl = i_args.ku = tester.kBufferSize; + i_args.a_ld = i_args.b_ld = i_args.c_ld = tester.kBufferSize; + for (auto &x_size: x_sizes) { i_args.x_size = x_size; + for (auto &y_size: y_sizes) { i_args.y_size = y_size; + for (auto &a_size: a_sizes) { i_args.a_size = a_size; + for (auto &b_size: b_sizes) { i_args.b_size = b_size; + for (auto &c_size: c_sizes) { i_args.c_size = c_size; + for (auto &ap_size: ap_sizes) { i_args.ap_size = ap_size; + invalid_test_vector.push_back(i_args); + } } } } } } - } + #endif // Sets the name of this test-case auto names = std::vector<std::string>{}; @@ -287,7 +298,9 @@ void RunTests(int argc, char *argv[], const bool silent, const std::string &name // Runs the tests tester.TestRegular(regular_test_vector, case_name); - tester.TestInvalid(invalid_test_vector, case_name); + #ifdef CLBLAST_REF_CLBLAS + tester.TestInvalid(invalid_test_vector, case_name); + #endif } } } |