diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-04-01 13:36:24 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-04-01 13:36:24 +0200 |
commit | b84d2296b87ac212474af855d916b12adf96bdb7 (patch) | |
tree | 0f2e85e1e1acef1d22f046499dd0b8a30e5da4f9 /test/correctness | |
parent | a98c00a2671b8981579f3a73dca8fb3365a95e53 (diff) |
Separated host-device and device-host memory copies from execution of the CBLAS reference code; for fair timing and code de-duplication
Diffstat (limited to 'test/correctness')
-rw-r--r-- | test/correctness/testblas.cpp | 18 | ||||
-rw-r--r-- | test/correctness/testblas.hpp | 49 |
2 files changed, 44 insertions, 23 deletions
diff --git a/test/correctness/testblas.cpp b/test/correctness/testblas.cpp index c8c59fcf..1bfcb623 100644 --- a/test/correctness/testblas.cpp +++ b/test/correctness/testblas.cpp @@ -67,15 +67,17 @@ TestBlas<T,U>::TestBlas(const std::vector<std::string> &arguments, const bool si kBetaValues(GetExampleScalars<U>(full_test_)), prepare_data_(prepare_data), run_routine_(run_routine), + run_reference1_(run_reference1), + run_reference2_(run_reference2), get_result_(get_result), get_index_(get_index), get_id1_(get_id1), get_id2_(get_id2) { - // Sets the reference to test against - if (compare_clblas_) { run_reference_ = run_reference1; } - else if (compare_cblas_) { run_reference_ = run_reference2; } - else { throw std::runtime_error("Invalid configuration: no reference to test against"); } + // Sanity check + if (!compare_clblas_ && !compare_cblas_) { + throw std::runtime_error("Invalid configuration: no reference to test against"); + } // Computes the maximum sizes. This allows for a single set of input/output buffers. const auto max_vec = *std::max_element(kVectorDims.begin(), kVectorDims.end()); @@ -184,7 +186,9 @@ void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const st else if (compare_cblas_) { fprintf(stdout, " [CPU BLAS]"); } std::cout << std::flush; } - const auto status1 = run_reference_(args, buffers1, queue_); + auto status1 = StatusCode::kSuccess; + if (compare_clblas_) { status1 = run_reference1_(args, buffers1, queue_); } + else if (compare_cblas_) { status1 = run_reference2_(args, buffers1, queue_); } // Tests for equality of the two status codes if (verbose_) { fprintf(stdout, " -> "); std::cout << std::flush; } @@ -305,7 +309,9 @@ void TestBlas<T,U>::TestInvalid(std::vector<Arguments<U>> &test_vector, const st else if (compare_cblas_) { fprintf(stdout, " [CPU BLAS]"); } std::cout << std::flush; } - const auto status1 = run_reference_(args, buffers1, queue_); + auto status1 = StatusCode::kSuccess; + if (compare_clblas_) { status1 = run_reference1_(args, buffers1, queue_); } + else if (compare_cblas_) { status1 = run_reference2_(args, buffers1, queue_); } // Tests for equality of the two status codes if (verbose_) { fprintf(stdout, " -> "); std::cout << std::flush; } diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index 8c8db348..560ff4d3 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -109,33 +109,48 @@ class TestBlas: public Tester<T,U> { std::vector<T> scalar_source_; // The routine-specific functions passed to the tester - DataPrepare prepare_data_; - Routine run_routine_; - Routine run_reference_; - ResultGet get_result_; - ResultIndex get_index_; - ResultIterator get_id1_; - ResultIterator get_id2_; + const DataPrepare prepare_data_; + const Routine run_routine_; + const Routine run_reference1_; + const Routine run_reference2_; + const ResultGet get_result_; + const ResultIndex get_index_; + const ResultIterator get_id1_; + const ResultIterator get_id2_; }; // ================================================================================================= +// Bogus reference function, in case a comparison library is not available +template <typename T, typename U, typename BufferType> +static StatusCode ReferenceNotAvailable(const Arguments<U> &, BufferType &, Queue &) { + return StatusCode::kNotImplemented; +} + // The interface to the correctness tester. This is a separate function in the header such that it // is automatically compiled for each routine, templated by the parameter "C". template <typename C, typename T, typename U> size_t RunTests(int argc, char *argv[], const bool silent, const std::string &name) { auto command_line_args = RetrieveCommandLineArguments(argc, argv); - // Sets the reference to test against - #if defined(CLBLAST_REF_CLBLAS) && defined(CLBLAST_REF_CBLAS) - const auto reference_routine1 = C::RunReference1; // clBLAS - const auto reference_routine2 = C::RunReference2; // CBLAS - #elif CLBLAST_REF_CLBLAS - const auto reference_routine1 = C::RunReference1; // clBLAS - const auto reference_routine2 = C::RunReference1; // not used, dummy - #elif CLBLAST_REF_CBLAS - const auto reference_routine1 = C::RunReference2; // not used, dummy - const auto reference_routine2 = C::RunReference2; // CBLAS + // Sets the clBLAS reference to test against + #ifdef CLBLAST_REF_CLBLAS + auto reference_routine1 = C::RunReference1; // clBLAS when available + #else + auto reference_routine1 = ReferenceNotAvailable<T,U,Buffers<T>>; + #endif + + // Sets the CBLAS reference to test against + #ifdef CLBLAST_REF_CBLAS + auto reference_routine2 = [](const Arguments<U> &args, Buffers<T> &buffers, Queue &queue) -> StatusCode { + auto buffers_host = BuffersHost<T>(); + DeviceToHost(args, buffers, buffers_host, queue, C::BuffersIn()); + C::RunReference2(args, buffers_host, queue); + HostToDevice(args, buffers, buffers_host, queue, C::BuffersOut()); + return StatusCode::kSuccess; + }; + #else + auto reference_routine2 = ReferenceNotAvailable<T,U,Buffers<T>>; #endif // Non-BLAS routines cannot be fully tested |