summaryrefslogtreecommitdiff
path: root/test/correctness/testblas.h
diff options
context:
space:
mode:
Diffstat (limited to 'test/correctness/testblas.h')
-rw-r--r--test/correctness/testblas.h47
1 files changed, 30 insertions, 17 deletions
diff --git a/test/correctness/testblas.h b/test/correctness/testblas.h
index 7c9032bd..8181aaf6 100644
--- a/test/correctness/testblas.h
+++ b/test/correctness/testblas.h
@@ -68,7 +68,7 @@ class TestBlas: public Tester<T,U> {
static const std::vector<Transpose> kTransposes; // Data-type dependent, see .cc-file
// Shorthand for the routine-specific functions passed to the tester
- using Routine = std::function<StatusCode(const Arguments<U>&, const Buffers<T>&, Queue&)>;
+ using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
using ResultGet = std::function<std::vector<T>(const Arguments<U>&, Buffers<T>&, Queue&)>;
using ResultIndex = std::function<size_t(const Arguments<U>&, const size_t, const size_t)>;
using ResultIterator = std::function<size_t(const Arguments<U>&)>;
@@ -76,8 +76,9 @@ class TestBlas: public Tester<T,U> {
// Constructor, initializes the base class tester and input data
TestBlas(int argc, char *argv[], const bool silent,
const std::string &name, const std::vector<std::string> &options,
- const Routine run_routine, const Routine run_reference, const ResultGet get_result,
- const ResultIndex get_index, const ResultIterator get_id1, const ResultIterator get_id2);
+ const Routine run_routine, const Routine run_reference,
+ const ResultGet get_result, const ResultIndex get_index,
+ const ResultIterator get_id1, const ResultIterator get_id2);
// The test functions, taking no inputs
void TestRegular(std::vector<Arguments<U>> &test_vector, const std::string &name);
@@ -110,9 +111,17 @@ class TestBlas: public Tester<T,U> {
template <typename C, typename T, typename U>
void RunTests(int argc, char *argv[], const bool silent, const std::string &name) {
+ // Sets the reference to test against
+ #ifdef CLBLAST_REF_CLBLAS
+ const auto reference_routine = C::RunReference1; // clBLAS when available
+ #else
+ const auto reference_routine = C::RunReference2; // otherwise CBLAS
+ #endif
+
// Creates a tester
auto options = C::GetOptions();
- TestBlas<T,U> tester{argc, argv, silent, name, options, C::RunRoutine, C::RunReference,
+ TestBlas<T,U> tester{argc, argv, silent, name, options,
+ C::RunRoutine, reference_routine,
C::DownloadResult, C::GetResultIndex, C::ResultID1, C::ResultID2};
// This variable holds the arguments relevant for this routine
@@ -250,23 +259,25 @@ void RunTests(int argc, char *argv[], const bool silent, const std::string &name
}
// Creates the arguments vector for the invalid-buffer tests
- auto invalid_test_vector = std::vector<Arguments<U>>{};
- auto i_args = args;
- i_args.m = i_args.n = i_args.k = i_args.kl = i_args.ku = tester.kBufferSize;
- i_args.a_ld = i_args.b_ld = i_args.c_ld = tester.kBufferSize;
- for (auto &x_size: x_sizes) { i_args.x_size = x_size;
- for (auto &y_size: y_sizes) { i_args.y_size = y_size;
- for (auto &a_size: a_sizes) { i_args.a_size = a_size;
- for (auto &b_size: b_sizes) { i_args.b_size = b_size;
- for (auto &c_size: c_sizes) { i_args.c_size = c_size;
- for (auto &ap_size: ap_sizes) { i_args.ap_size = ap_size;
- invalid_test_vector.push_back(i_args);
+ #ifdef CLBLAST_REF_CLBLAS
+ auto invalid_test_vector = std::vector<Arguments<U>>{};
+ auto i_args = args;
+ i_args.m = i_args.n = i_args.k = i_args.kl = i_args.ku = tester.kBufferSize;
+ i_args.a_ld = i_args.b_ld = i_args.c_ld = tester.kBufferSize;
+ for (auto &x_size: x_sizes) { i_args.x_size = x_size;
+ for (auto &y_size: y_sizes) { i_args.y_size = y_size;
+ for (auto &a_size: a_sizes) { i_args.a_size = a_size;
+ for (auto &b_size: b_sizes) { i_args.b_size = b_size;
+ for (auto &c_size: c_sizes) { i_args.c_size = c_size;
+ for (auto &ap_size: ap_sizes) { i_args.ap_size = ap_size;
+ invalid_test_vector.push_back(i_args);
+ }
}
}
}
}
}
- }
+ #endif
// Sets the name of this test-case
auto names = std::vector<std::string>{};
@@ -287,7 +298,9 @@ void RunTests(int argc, char *argv[], const bool silent, const std::string &name
// Runs the tests
tester.TestRegular(regular_test_vector, case_name);
- tester.TestInvalid(invalid_test_vector, case_name);
+ #ifdef CLBLAST_REF_CLBLAS
+ tester.TestInvalid(invalid_test_vector, case_name);
+ #endif
}
}
}