summaryrefslogtreecommitdiff
path: root/test/correctness
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-03-04 15:21:33 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-03-04 15:21:33 +0100
commite993ee077b50d3a6134309d465a4174b5c749596 (patch)
treeb967f2702b90d8080a3e3cb41b9cbc01ab9eddc3 /test/correctness
parent3fc73851f7ed885335940eb85e53069638567323 (diff)
Added a proper data-preparation function for the TRSM tests
Diffstat (limited to 'test/correctness')
-rw-r--r--test/correctness/testblas.cpp7
-rw-r--r--test/correctness/testblas.hpp8
-rw-r--r--test/correctness/tester.cpp2
3 files changed, 16 insertions, 1 deletions
diff --git a/test/correctness/testblas.cpp b/test/correctness/testblas.cpp
index 5207c0ab..d959ce18 100644
--- a/test/correctness/testblas.cpp
+++ b/test/correctness/testblas.cpp
@@ -51,6 +51,7 @@ template <> const std::vector<Transpose> TestBlas<double2,double>::kTransposes =
template <typename T, typename U>
TestBlas<T,U>::TestBlas(const std::vector<std::string> &arguments, const bool silent,
const std::string &name, const std::vector<std::string> &options,
+ const DataPrepare prepare_data,
const Routine run_routine,
const Routine run_reference1, const Routine run_reference2,
const ResultGet get_result, const ResultIndex get_index,
@@ -59,6 +60,7 @@ TestBlas<T,U>::TestBlas(const std::vector<std::string> &arguments, const bool si
kOffsets(GetOffsets()),
kAlphaValues(GetExampleScalars<U>(full_test_)),
kBetaValues(GetExampleScalars<U>(full_test_)),
+ prepare_data_(prepare_data),
run_routine_(run_routine),
get_result_(get_result),
get_index_(get_index),
@@ -112,6 +114,11 @@ void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const st
std::cout << std::flush;
}
+ // Optionally prepares the input data
+ prepare_data_(args, queue_, kSeed,
+ x_source_, y_source_, a_source_, b_source_, c_source_,
+ ap_source_, scalar_source_);
+
// Set-up for the CLBlast run
auto x_vec2 = Buffer<T>(context_, args.x_size);
auto y_vec2 = Buffer<T>(context_, args.y_size);
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp
index 27fd84c3..ee795aad 100644
--- a/test/correctness/testblas.hpp
+++ b/test/correctness/testblas.hpp
@@ -74,6 +74,10 @@ class TestBlas: public Tester<T,U> {
static const std::vector<Transpose> kTransposes; // Data-type dependent, see .cc-file
// Shorthand for the routine-specific functions passed to the tester
+ using DataPrepare = std::function<void(const Arguments<U>&, Queue&, const int,
+ std::vector<T>&, std::vector<T>&,
+ std::vector<T>&, std::vector<T>&, std::vector<T>&,
+ std::vector<T>&, std::vector<T>&)>;
using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
using ResultGet = std::function<std::vector<T>(const Arguments<U>&, Buffers<T>&, Queue&)>;
using ResultIndex = std::function<size_t(const Arguments<U>&, const size_t, const size_t)>;
@@ -82,6 +86,7 @@ class TestBlas: public Tester<T,U> {
// Constructor, initializes the base class tester and input data
TestBlas(const std::vector<std::string> &arguments, const bool silent,
const std::string &name, const std::vector<std::string> &options,
+ const DataPrepare prepare_data,
const Routine run_routine,
const Routine run_reference1, const Routine run_reference2,
const ResultGet get_result, const ResultIndex get_index,
@@ -103,6 +108,7 @@ class TestBlas: public Tester<T,U> {
std::vector<T> scalar_source_;
// The routine-specific functions passed to the tester
+ DataPrepare prepare_data_;
Routine run_routine_;
Routine run_reference_;
ResultGet get_result_;
@@ -141,7 +147,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
// Creates a tester
auto options = C::GetOptions();
TestBlas<T,U> tester{command_line_args, silent, name, options,
- C::RunRoutine, reference_routine1, reference_routine2,
+ C::PrepareData, C::RunRoutine, reference_routine1, reference_routine2,
C::DownloadResult, C::GetResultIndex, C::ResultID1, C::ResultID2};
// This variable holds the arguments relevant for this routine
diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp
index 046473f8..cbfc5bb2 100644
--- a/test/correctness/tester.cpp
+++ b/test/correctness/tester.cpp
@@ -365,6 +365,8 @@ std::string Tester<T,U>::GetOptionsString(const Arguments<U> &args) {
if (o == kArgCOffset) { result += kArgCOffset + equals + ToString(args.c_offset) + " "; }
if (o == kArgAPOffset) { result += kArgAPOffset + equals + ToString(args.ap_offset) + " "; }
if (o == kArgDotOffset){ result += kArgDotOffset + equals + ToString(args.dot_offset) + " "; }
+ if (o == kArgAlpha) { result += kArgAlpha + equals + ToString(args.alpha) + " "; }
+ if (o == kArgBeta) { result += kArgBeta + equals + ToString(args.beta) + " "; }
}
return result;
}