summaryrefslogtreecommitdiff
path: root/test/correctness
diff options
context:
space:
mode:
Diffstat (limited to 'test/correctness')
-rw-r--r--test/correctness/testblas.cpp18
-rw-r--r--test/correctness/testblas.hpp49
2 files changed, 44 insertions, 23 deletions
diff --git a/test/correctness/testblas.cpp b/test/correctness/testblas.cpp
index c8c59fcf..1bfcb623 100644
--- a/test/correctness/testblas.cpp
+++ b/test/correctness/testblas.cpp
@@ -67,15 +67,17 @@ TestBlas<T,U>::TestBlas(const std::vector<std::string> &arguments, const bool si
kBetaValues(GetExampleScalars<U>(full_test_)),
prepare_data_(prepare_data),
run_routine_(run_routine),
+ run_reference1_(run_reference1),
+ run_reference2_(run_reference2),
get_result_(get_result),
get_index_(get_index),
get_id1_(get_id1),
get_id2_(get_id2) {
- // Sets the reference to test against
- if (compare_clblas_) { run_reference_ = run_reference1; }
- else if (compare_cblas_) { run_reference_ = run_reference2; }
- else { throw std::runtime_error("Invalid configuration: no reference to test against"); }
+ // Sanity check
+ if (!compare_clblas_ && !compare_cblas_) {
+ throw std::runtime_error("Invalid configuration: no reference to test against");
+ }
// Computes the maximum sizes. This allows for a single set of input/output buffers.
const auto max_vec = *std::max_element(kVectorDims.begin(), kVectorDims.end());
@@ -184,7 +186,9 @@ void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const st
else if (compare_cblas_) { fprintf(stdout, " [CPU BLAS]"); }
std::cout << std::flush;
}
- const auto status1 = run_reference_(args, buffers1, queue_);
+ auto status1 = StatusCode::kSuccess;
+ if (compare_clblas_) { status1 = run_reference1_(args, buffers1, queue_); }
+ else if (compare_cblas_) { status1 = run_reference2_(args, buffers1, queue_); }
// Tests for equality of the two status codes
if (verbose_) { fprintf(stdout, " -> "); std::cout << std::flush; }
@@ -305,7 +309,9 @@ void TestBlas<T,U>::TestInvalid(std::vector<Arguments<U>> &test_vector, const st
else if (compare_cblas_) { fprintf(stdout, " [CPU BLAS]"); }
std::cout << std::flush;
}
- const auto status1 = run_reference_(args, buffers1, queue_);
+ auto status1 = StatusCode::kSuccess;
+ if (compare_clblas_) { status1 = run_reference1_(args, buffers1, queue_); }
+ else if (compare_cblas_) { status1 = run_reference2_(args, buffers1, queue_); }
// Tests for equality of the two status codes
if (verbose_) { fprintf(stdout, " -> "); std::cout << std::flush; }
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp
index 8c8db348..560ff4d3 100644
--- a/test/correctness/testblas.hpp
+++ b/test/correctness/testblas.hpp
@@ -109,33 +109,48 @@ class TestBlas: public Tester<T,U> {
std::vector<T> scalar_source_;
// The routine-specific functions passed to the tester
- DataPrepare prepare_data_;
- Routine run_routine_;
- Routine run_reference_;
- ResultGet get_result_;
- ResultIndex get_index_;
- ResultIterator get_id1_;
- ResultIterator get_id2_;
+ const DataPrepare prepare_data_;
+ const Routine run_routine_;
+ const Routine run_reference1_;
+ const Routine run_reference2_;
+ const ResultGet get_result_;
+ const ResultIndex get_index_;
+ const ResultIterator get_id1_;
+ const ResultIterator get_id2_;
};
// =================================================================================================
+// Bogus reference function, in case a comparison library is not available
+template <typename T, typename U, typename BufferType>
+static StatusCode ReferenceNotAvailable(const Arguments<U> &, BufferType &, Queue &) {
+ return StatusCode::kNotImplemented;
+}
+
// The interface to the correctness tester. This is a separate function in the header such that it
// is automatically compiled for each routine, templated by the parameter "C".
template <typename C, typename T, typename U>
size_t RunTests(int argc, char *argv[], const bool silent, const std::string &name) {
auto command_line_args = RetrieveCommandLineArguments(argc, argv);
- // Sets the reference to test against
- #if defined(CLBLAST_REF_CLBLAS) && defined(CLBLAST_REF_CBLAS)
- const auto reference_routine1 = C::RunReference1; // clBLAS
- const auto reference_routine2 = C::RunReference2; // CBLAS
- #elif CLBLAST_REF_CLBLAS
- const auto reference_routine1 = C::RunReference1; // clBLAS
- const auto reference_routine2 = C::RunReference1; // not used, dummy
- #elif CLBLAST_REF_CBLAS
- const auto reference_routine1 = C::RunReference2; // not used, dummy
- const auto reference_routine2 = C::RunReference2; // CBLAS
+ // Sets the clBLAS reference to test against
+ #ifdef CLBLAST_REF_CLBLAS
+ auto reference_routine1 = C::RunReference1; // clBLAS when available
+ #else
+ auto reference_routine1 = ReferenceNotAvailable<T,U,Buffers<T>>;
+ #endif
+
+ // Sets the CBLAS reference to test against
+ #ifdef CLBLAST_REF_CBLAS
+ auto reference_routine2 = [](const Arguments<U> &args, Buffers<T> &buffers, Queue &queue) -> StatusCode {
+ auto buffers_host = BuffersHost<T>();
+ DeviceToHost(args, buffers, buffers_host, queue, C::BuffersIn());
+ C::RunReference2(args, buffers_host, queue);
+ HostToDevice(args, buffers, buffers_host, queue, C::BuffersOut());
+ return StatusCode::kSuccess;
+ };
+ #else
+ auto reference_routine2 = ReferenceNotAvailable<T,U,Buffers<T>>;
#endif
// Non-BLAS routines cannot be fully tested