diff options
Diffstat (limited to 'test/correctness/tester.cc')
-rw-r--r-- | test/correctness/tester.cc | 47 |
1 files changed, 41 insertions, 6 deletions
diff --git a/test/correctness/tester.cc b/test/correctness/tester.cc index 51d83362..82926c3c 100644 --- a/test/correctness/tester.cc +++ b/test/correctness/tester.cc @@ -43,9 +43,32 @@ Tester<T,U>::Tester(int argc, char *argv[], const bool silent, tests_failed_{0}, options_{options} { + // Determines which reference to test against + #if defined(CLBLAST_REF_CLBLAS) && defined(CLBLAST_REF_CBLAS) + compare_clblas_ = GetArgument(argc, argv, help_, kArgCompareclblas, 1); + compare_cblas_ = GetArgument(argc, argv, help_, kArgComparecblas, 0); + #elif CLBLAST_REF_CLBLAS + compare_clblas_ = GetArgument(argc, argv, help_, kArgCompareclblas, 1); + compare_cblas_ = 0; + #elif CLBLAST_REF_CBLAS + compare_clblas_ = 0; + compare_cblas_ = GetArgument(argc, argv, help_, kArgComparecblas, 1); + #else + compare_clblas_ = 0; + compare_cblas_ = 0; + #endif + // Prints the help message (command-line arguments) if (!silent) { fprintf(stdout, "\n* %s\n", help_.c_str()); } + // Can only test against a single reference (not two, not zero) + if (compare_clblas_ && compare_cblas_) { + throw std::runtime_error("Cannot test against both clBLAS and CBLAS references; choose one using the -cblas and -clblas arguments"); + } + if (!compare_clblas_ && !compare_cblas_) { + throw std::runtime_error("Choose one reference (clBLAS or CBLAS) to test against using the -cblas and -clblas arguments"); + } + // Prints the header fprintf(stdout, "* Running on OpenCL device '%s'.\n", device_.Name().c_str()); fprintf(stdout, "* Starting tests for the %s'%s'%s routine.", @@ -68,12 +91,16 @@ Tester<T,U>::Tester(int argc, char *argv[], const bool silent, kSkippedCompilation.c_str()); fprintf(stdout, " %s -> Test not executed: Unsupported precision\n", kUnsupportedPrecision.c_str()); + fprintf(stdout, " %s -> Test not completed: Reference CBLAS doesn't output error codes\n", + kUnsupportedReference.c_str()); // Initializes clBLAS #ifdef CLBLAST_REF_CLBLAS - auto status = clblasSetup(); - if (status != CL_SUCCESS) { - throw std::runtime_error("clBLAS setup error: "+ToString(static_cast<int>(status))); + if (compare_clblas_) { + auto status = clblasSetup(); + if (status != CL_SUCCESS) { + throw std::runtime_error("clBLAS setup error: "+ToString(static_cast<int>(status))); + } } #endif } @@ -93,7 +120,9 @@ Tester<T,U>::~Tester() { // Cleans-up clBLAS #ifdef CLBLAST_REF_CLBLAS - clblasTeardown(); + if (compare_clblas_) { + clblasTeardown(); + } #endif } @@ -124,7 +153,7 @@ template <typename T, typename U> void Tester<T,U>::TestEnd() { fprintf(stdout, "\n"); tests_passed_ += num_passed_; - tests_failed_ += num_skipped_; + tests_skipped_ += num_skipped_; tests_failed_ += num_failed_; // Prints the errors @@ -174,8 +203,14 @@ template <typename T, typename U> void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status, const Arguments<U> &args) { + // Cannot compare error codes against a library other than clBLAS + if (compare_cblas_) { + PrintTestResult(kUnsupportedReference); + ReportSkipped(); + } + // Finished successfully - if (clblas_status == clblast_status) { + else if (clblas_status == clblast_status) { PrintTestResult(kSuccessStatus); ReportPass(); } |