summaryrefslogtreecommitdiff
path: root/test/correctness/testblas.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/correctness/testblas.hpp')
-rw-r--r--test/correctness/testblas.hpp11
1 files changed, 8 insertions, 3 deletions
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp
index ee795aad..e675fa9b 100644
--- a/test/correctness/testblas.hpp
+++ b/test/correctness/testblas.hpp
@@ -56,6 +56,7 @@ class TestBlas: public Tester<T,U> {
static const std::vector<size_t> kMatrixDims;
static const std::vector<size_t> kMatrixVectorDims;
static const std::vector<size_t> kBandSizes;
+ static const std::vector<size_t> kBatchCounts;
const std::vector<size_t> kOffsets;
const std::vector<U> kAlphaValues;
const std::vector<U> kBetaValues;
@@ -78,7 +79,7 @@ class TestBlas: public Tester<T,U> {
std::vector<T>&, std::vector<T>&,
std::vector<T>&, std::vector<T>&, std::vector<T>&,
std::vector<T>&, std::vector<T>&)>;
- using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
+ using Routine = std::function<StatusCode(const Arguments<U>&, std::vector<Buffers<T>>&, Queue&)>;
using ResultGet = std::function<std::vector<T>(const Arguments<U>&, Buffers<T>&, Queue&)>;
using ResultIndex = std::function<size_t(const Arguments<U>&, const size_t, const size_t)>;
using ResultIterator = std::function<size_t(const Arguments<U>&)>;
@@ -183,6 +184,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
auto imax_offsets = std::vector<size_t>{args.imax_offset};
auto alphas = std::vector<U>{args.alpha};
auto betas = std::vector<U>{args.beta};
+ auto batch_counts = std::vector<size_t>{args.batch_count};
auto x_sizes = std::vector<size_t>{args.x_size};
auto y_sizes = std::vector<size_t>{args.y_size};
auto a_sizes = std::vector<size_t>{args.a_size};
@@ -226,6 +228,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; }
if (option == kArgAlpha) { alphas = tester.kAlphaValues; }
if (option == kArgBeta) { betas = tester.kBetaValues; }
+ if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; }
if (option == kArgXOffset) { x_sizes = tester.kVecSizes; }
if (option == kArgYOffset) { y_sizes = tester.kVecSizes; }
@@ -268,8 +271,10 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
for (auto &imax_offset: imax_offsets) { r_args.imax_offset = imax_offset;
for (auto &alpha: alphas) { r_args.alpha = alpha;
for (auto &beta: betas) { r_args.beta = beta;
- C::SetSizes(r_args);
- regular_test_vector.push_back(r_args);
+ for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count;
+ C::SetSizes(r_args);
+ regular_test_vector.push_back(r_args);
+ }
}
}
}