diff options
Diffstat (limited to 'test/correctness/testblas.cc')
-rw-r--r-- | test/correctness/testblas.cc | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/test/correctness/testblas.cc b/test/correctness/testblas.cc index 6bcba267..1f83c59b 100644 --- a/test/correctness/testblas.cc +++ b/test/correctness/testblas.cc @@ -33,17 +33,22 @@ template <> const std::vector<Transpose> TestBlas<double2,double>::kTransposes = template <typename T, typename U> TestBlas<T,U>::TestBlas(int argc, char *argv[], const bool silent, const std::string &name, const std::vector<std::string> &options, - const Routine run_routine, const Routine run_reference, + const Routine run_routine, + const Routine run_reference1, const Routine run_reference2, const ResultGet get_result, const ResultIndex get_index, const ResultIterator get_id1, const ResultIterator get_id2): Tester<T,U>(argc, argv, silent, name, options), run_routine_(run_routine), - run_reference_(run_reference), get_result_(get_result), get_index_(get_index), get_id1_(get_id1), get_id2_(get_id2) { + // Sets the reference to test against + if (compare_clblas_) { run_reference_ = run_reference1; } + else if (compare_cblas_) { run_reference_ = run_reference2; } + else { throw std::runtime_error("Invalid configuration: no reference to test against"); } + // Computes the maximum sizes. This allows for a single set of input/output buffers. auto max_vec = *std::max_element(kVectorDims.begin(), kVectorDims.end()); auto max_inc = *std::max_element(kIncrements.begin(), kIncrements.end()); @@ -98,14 +103,11 @@ void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const st auto buffers2 = Buffers<T>{x_vec2, y_vec2, a_mat2, b_mat2, c_mat2, ap_mat2, scalar2}; auto status2 = run_routine_(args, buffers2, queue_); - #ifndef CLBLAST_REF_CLBLAS - // Don't continue with CBLAS if there are incorrect parameters - if (status2 != StatusCode::kSuccess) { - // TODO: Mark this as a skipped test instead of a succesfull test - TestErrorCodes(status2, status2, args); - continue; - } - #endif + // Don't continue with CBLAS if there are incorrect parameters + if (compare_cblas_ && status2 != StatusCode::kSuccess) { + TestErrorCodes(status2, status2, args); + continue; + } // Runs the reference BLAS code auto x_vec1 = Buffer<T>(context_, args.x_size); @@ -168,6 +170,7 @@ void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const st template <typename T, typename U> void TestBlas<T,U>::TestInvalid(std::vector<Arguments<U>> &test_vector, const std::string &name) { if (!PrecisionSupported<T>(device_)) { return; } + if (!compare_clblas_) { return; } TestStart("invalid buffer sizes", name); // Iterates over all the to-be-tested combinations of arguments |