diff options
Diffstat (limited to 'test/correctness/testblas.hpp')
-rw-r--r-- | test/correctness/testblas.hpp | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index 54b2d6f8..1d1d2ca9 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -60,6 +60,7 @@ class TestBlas: public Tester<T,U> { static const std::vector<size_t> kDilationSizes; static const std::vector<size_t> kKernelSizes; static const std::vector<size_t> kBatchCounts; + static const std::vector<size_t> kNumKernels; const std::vector<size_t> kOffsets; const std::vector<U> kAlphaValues; const std::vector<U> kBetaValues; @@ -136,6 +137,7 @@ template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBatc template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadSizes = { 0, 1 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kDilationSizes = { 1, 2 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kKernelSizes = { 1, 3 }; +template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kNumKernels = { 1, 2 }; // Test settings for the invalid tests template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kInvalidIncrements = { 0, 1 }; @@ -241,6 +243,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na auto dilation_hs = std::vector<size_t>{args.dilation_h}; auto dilation_ws = std::vector<size_t>{args.dilation_w}; auto batch_counts = std::vector<size_t>{args.batch_count}; + auto num_kernelss = std::vector<size_t>{args.num_kernels}; auto x_sizes = std::vector<size_t>{args.x_size}; auto y_sizes = std::vector<size_t>{args.y_size}; auto a_sizes = std::vector<size_t>{args.a_size}; @@ -296,6 +299,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgDilationH) { dilation_hs = tester.kDilationSizes; } if (option == kArgDilationW) { dilation_ws = tester.kDilationSizes; } if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; } + if (option == kArgNumKernels) { num_kernelss = tester.kNumKernels; } if (option == kArgXOffset) { x_sizes = tester.kVecSizes; } if (option == kArgYOffset) { y_sizes = tester.kVecSizes; } @@ -350,8 +354,10 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na for (auto &dilation_h: dilation_hs) { r_args.dilation_h = dilation_h; for (auto &dilation_w: dilation_ws) { r_args.dilation_w = dilation_w; for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count; - C::SetSizes(r_args, tester.queue_); - regular_test_vector.push_back(r_args); + for (auto &num_kernels: num_kernelss) { r_args.num_kernels = num_kernels; + C::SetSizes(r_args, tester.queue_); + regular_test_vector.push_back(r_args); + } } } } |