diff options
Diffstat (limited to 'test/correctness/testblas.hpp')
-rw-r--r-- | test/correctness/testblas.hpp | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index ee795aad..e675fa9b 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -56,6 +56,7 @@ class TestBlas: public Tester<T,U> { static const std::vector<size_t> kMatrixDims; static const std::vector<size_t> kMatrixVectorDims; static const std::vector<size_t> kBandSizes; + static const std::vector<size_t> kBatchCounts; const std::vector<size_t> kOffsets; const std::vector<U> kAlphaValues; const std::vector<U> kBetaValues; @@ -78,7 +79,7 @@ class TestBlas: public Tester<T,U> { std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&)>; - using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>; + using Routine = std::function<StatusCode(const Arguments<U>&, std::vector<Buffers<T>>&, Queue&)>; using ResultGet = std::function<std::vector<T>(const Arguments<U>&, Buffers<T>&, Queue&)>; using ResultIndex = std::function<size_t(const Arguments<U>&, const size_t, const size_t)>; using ResultIterator = std::function<size_t(const Arguments<U>&)>; @@ -183,6 +184,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na auto imax_offsets = std::vector<size_t>{args.imax_offset}; auto alphas = std::vector<U>{args.alpha}; auto betas = std::vector<U>{args.beta}; + auto batch_counts = std::vector<size_t>{args.batch_count}; auto x_sizes = std::vector<size_t>{args.x_size}; auto y_sizes = std::vector<size_t>{args.y_size}; auto a_sizes = std::vector<size_t>{args.a_size}; @@ -226,6 +228,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; } if (option == kArgAlpha) { alphas = tester.kAlphaValues; } if (option == kArgBeta) { betas = tester.kBetaValues; } + if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; } if (option == kArgXOffset) { x_sizes = tester.kVecSizes; } if (option == kArgYOffset) { y_sizes = tester.kVecSizes; } @@ -268,8 +271,10 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na for (auto &imax_offset: imax_offsets) { r_args.imax_offset = imax_offset; for (auto &alpha: alphas) { r_args.alpha = alpha; for (auto &beta: betas) { r_args.beta = beta; - C::SetSizes(r_args); - regular_test_vector.push_back(r_args); + for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count; + C::SetSizes(r_args); + regular_test_vector.push_back(r_args); + } } } } |