summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-05-07 12:22:06 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-05-07 12:22:06 +0200
commit6c9e08c5e288767d9afedb118c37694f63739cae (patch)
treebf27e64913de3064b8c37988db1252d389dbe22c /test
parent56aa1701c955546e049ec0dbe5b2777d592b5fc1 (diff)
Added an option to the tests to control whether to test against clBLAS or a CPU BLAS library
Diffstat (limited to 'test')
-rw-r--r--test/correctness/testblas.cc23
-rw-r--r--test/correctness/testblas.h20
-rw-r--r--test/correctness/tester.cc47
-rw-r--r--test/correctness/tester.h5
4 files changed, 73 insertions, 22 deletions
diff --git a/test/correctness/testblas.cc b/test/correctness/testblas.cc
index 6bcba267..1f83c59b 100644
--- a/test/correctness/testblas.cc
+++ b/test/correctness/testblas.cc
@@ -33,17 +33,22 @@ template <> const std::vector<Transpose> TestBlas<double2,double>::kTransposes =
template <typename T, typename U>
TestBlas<T,U>::TestBlas(int argc, char *argv[], const bool silent,
const std::string &name, const std::vector<std::string> &options,
- const Routine run_routine, const Routine run_reference,
+ const Routine run_routine,
+ const Routine run_reference1, const Routine run_reference2,
const ResultGet get_result, const ResultIndex get_index,
const ResultIterator get_id1, const ResultIterator get_id2):
Tester<T,U>(argc, argv, silent, name, options),
run_routine_(run_routine),
- run_reference_(run_reference),
get_result_(get_result),
get_index_(get_index),
get_id1_(get_id1),
get_id2_(get_id2) {
+ // Sets the reference to test against
+ if (compare_clblas_) { run_reference_ = run_reference1; }
+ else if (compare_cblas_) { run_reference_ = run_reference2; }
+ else { throw std::runtime_error("Invalid configuration: no reference to test against"); }
+
// Computes the maximum sizes. This allows for a single set of input/output buffers.
auto max_vec = *std::max_element(kVectorDims.begin(), kVectorDims.end());
auto max_inc = *std::max_element(kIncrements.begin(), kIncrements.end());
@@ -98,14 +103,11 @@ void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const st
auto buffers2 = Buffers<T>{x_vec2, y_vec2, a_mat2, b_mat2, c_mat2, ap_mat2, scalar2};
auto status2 = run_routine_(args, buffers2, queue_);
- #ifndef CLBLAST_REF_CLBLAS
- // Don't continue with CBLAS if there are incorrect parameters
- if (status2 != StatusCode::kSuccess) {
- // TODO: Mark this as a skipped test instead of a succesfull test
- TestErrorCodes(status2, status2, args);
- continue;
- }
- #endif
+ // Don't continue with CBLAS if there are incorrect parameters
+ if (compare_cblas_ && status2 != StatusCode::kSuccess) {
+ TestErrorCodes(status2, status2, args);
+ continue;
+ }
// Runs the reference BLAS code
auto x_vec1 = Buffer<T>(context_, args.x_size);
@@ -168,6 +170,7 @@ void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const st
template <typename T, typename U>
void TestBlas<T,U>::TestInvalid(std::vector<Arguments<U>> &test_vector, const std::string &name) {
if (!PrecisionSupported<T>(device_)) { return; }
+ if (!compare_clblas_) { return; }
TestStart("invalid buffer sizes", name);
// Iterates over all the to-be-tested combinations of arguments
diff --git a/test/correctness/testblas.h b/test/correctness/testblas.h
index 8fd1b1e2..4ffc1558 100644
--- a/test/correctness/testblas.h
+++ b/test/correctness/testblas.h
@@ -37,6 +37,8 @@ class TestBlas: public Tester<T,U> {
using Tester<T,U>::full_test_;
using Tester<T,U>::verbose_;
using Tester<T,U>::device_;
+ using Tester<T,U>::compare_clblas_;
+ using Tester<T,U>::compare_cblas_;
// Uses several helper functions from the Tester class
using Tester<T,U>::TestStart;
@@ -77,7 +79,8 @@ class TestBlas: public Tester<T,U> {
// Constructor, initializes the base class tester and input data
TestBlas(int argc, char *argv[], const bool silent,
const std::string &name, const std::vector<std::string> &options,
- const Routine run_routine, const Routine run_reference,
+ const Routine run_routine,
+ const Routine run_reference1, const Routine run_reference2,
const ResultGet get_result, const ResultIndex get_index,
const ResultIterator get_id1, const ResultIterator get_id2);
@@ -113,16 +116,21 @@ template <typename C, typename T, typename U>
void RunTests(int argc, char *argv[], const bool silent, const std::string &name) {
// Sets the reference to test against
- #ifdef CLBLAST_REF_CLBLAS
- const auto reference_routine = C::RunReference1; // clBLAS when available
- #else
- const auto reference_routine = C::RunReference2; // otherwise CBLAS
+ #if defined(CLBLAST_REF_CLBLAS) && defined(CLBLAST_REF_CBLAS)
+ const auto reference_routine1 = C::RunReference1; // clBLAS
+ const auto reference_routine2 = C::RunReference2; // CBLAS
+ #elif CLBLAST_REF_CLBLAS
+ const auto reference_routine1 = C::RunReference1; // clBLAS
+ const auto reference_routine2 = C::RunReference1; // not used, dummy
+ #elif CLBLAST_REF_CBLAS
+ const auto reference_routine1 = C::RunReference2; // not used, dummy
+ const auto reference_routine2 = C::RunReference2; // CBLAS
#endif
// Creates a tester
auto options = C::GetOptions();
TestBlas<T,U> tester{argc, argv, silent, name, options,
- C::RunRoutine, reference_routine,
+ C::RunRoutine, reference_routine1, reference_routine2,
C::DownloadResult, C::GetResultIndex, C::ResultID1, C::ResultID2};
// This variable holds the arguments relevant for this routine
diff --git a/test/correctness/tester.cc b/test/correctness/tester.cc
index 51d83362..82926c3c 100644
--- a/test/correctness/tester.cc
+++ b/test/correctness/tester.cc
@@ -43,9 +43,32 @@ Tester<T,U>::Tester(int argc, char *argv[], const bool silent,
tests_failed_{0},
options_{options} {
+ // Determines which reference to test against
+ #if defined(CLBLAST_REF_CLBLAS) && defined(CLBLAST_REF_CBLAS)
+ compare_clblas_ = GetArgument(argc, argv, help_, kArgCompareclblas, 1);
+ compare_cblas_ = GetArgument(argc, argv, help_, kArgComparecblas, 0);
+ #elif CLBLAST_REF_CLBLAS
+ compare_clblas_ = GetArgument(argc, argv, help_, kArgCompareclblas, 1);
+ compare_cblas_ = 0;
+ #elif CLBLAST_REF_CBLAS
+ compare_clblas_ = 0;
+ compare_cblas_ = GetArgument(argc, argv, help_, kArgComparecblas, 1);
+ #else
+ compare_clblas_ = 0;
+ compare_cblas_ = 0;
+ #endif
+
// Prints the help message (command-line arguments)
if (!silent) { fprintf(stdout, "\n* %s\n", help_.c_str()); }
+ // Can only test against a single reference (not two, not zero)
+ if (compare_clblas_ && compare_cblas_) {
+ throw std::runtime_error("Cannot test against both clBLAS and CBLAS references; choose one using the -cblas and -clblas arguments");
+ }
+ if (!compare_clblas_ && !compare_cblas_) {
+ throw std::runtime_error("Choose one reference (clBLAS or CBLAS) to test against using the -cblas and -clblas arguments");
+ }
+
// Prints the header
fprintf(stdout, "* Running on OpenCL device '%s'.\n", device_.Name().c_str());
fprintf(stdout, "* Starting tests for the %s'%s'%s routine.",
@@ -68,12 +91,16 @@ Tester<T,U>::Tester(int argc, char *argv[], const bool silent,
kSkippedCompilation.c_str());
fprintf(stdout, " %s -> Test not executed: Unsupported precision\n",
kUnsupportedPrecision.c_str());
+ fprintf(stdout, " %s -> Test not completed: Reference CBLAS doesn't output error codes\n",
+ kUnsupportedReference.c_str());
// Initializes clBLAS
#ifdef CLBLAST_REF_CLBLAS
- auto status = clblasSetup();
- if (status != CL_SUCCESS) {
- throw std::runtime_error("clBLAS setup error: "+ToString(static_cast<int>(status)));
+ if (compare_clblas_) {
+ auto status = clblasSetup();
+ if (status != CL_SUCCESS) {
+ throw std::runtime_error("clBLAS setup error: "+ToString(static_cast<int>(status)));
+ }
}
#endif
}
@@ -93,7 +120,9 @@ Tester<T,U>::~Tester() {
// Cleans-up clBLAS
#ifdef CLBLAST_REF_CLBLAS
- clblasTeardown();
+ if (compare_clblas_) {
+ clblasTeardown();
+ }
#endif
}
@@ -124,7 +153,7 @@ template <typename T, typename U>
void Tester<T,U>::TestEnd() {
fprintf(stdout, "\n");
tests_passed_ += num_passed_;
- tests_failed_ += num_skipped_;
+ tests_skipped_ += num_skipped_;
tests_failed_ += num_failed_;
// Prints the errors
@@ -174,8 +203,14 @@ template <typename T, typename U>
void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status,
const Arguments<U> &args) {
+ // Cannot compare error codes against a library other than clBLAS
+ if (compare_cblas_) {
+ PrintTestResult(kUnsupportedReference);
+ ReportSkipped();
+ }
+
// Finished successfully
- if (clblas_status == clblast_status) {
+ else if (clblas_status == clblast_status) {
PrintTestResult(kSuccessStatus);
ReportPass();
}
diff --git a/test/correctness/tester.h b/test/correctness/tester.h
index 3534dffb..46d88caf 100644
--- a/test/correctness/tester.h
+++ b/test/correctness/tester.h
@@ -58,6 +58,7 @@ class Tester {
const std::string kErrorStatus{kPrintError + "/" + kPrintEnd};
const std::string kSkippedCompilation{kPrintWarning + "\\" + kPrintEnd};
const std::string kUnsupportedPrecision{kPrintWarning + "o" + kPrintEnd};
+ const std::string kUnsupportedReference{kPrintWarning + "." + kPrintEnd};
// This structure combines the above log-entry with a status code an error percentage
struct ErrorLogEntry {
@@ -102,6 +103,10 @@ class Tester {
// Retrieves the offset values to test with
const std::vector<size_t> GetOffsets() const;
+ // Testing against reference implementations
+ int compare_cblas_;
+ int compare_clblas_;
+
private:
// Internal methods to report a passed, skipped, or failed test