diff options
-rw-r--r-- | test/correctness/testblas.hpp | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index 1d1d2ca9..0dc8584e 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -61,6 +61,8 @@ class TestBlas: public Tester<T,U> { static const std::vector<size_t> kKernelSizes; static const std::vector<size_t> kBatchCounts; static const std::vector<size_t> kNumKernels; + static const std::vector<size_t> kStrideValues; + static const std::vector<size_t> kChannelValues; const std::vector<size_t> kOffsets; const std::vector<U> kAlphaValues; const std::vector<U> kBetaValues; @@ -138,6 +140,8 @@ template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadS template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kDilationSizes = { 1, 2 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kKernelSizes = { 1, 3 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kNumKernels = { 1, 2 }; +template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kStrideValues = { 1, 3 }; +template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kChannelValues = { 1, 4 }; // Test settings for the invalid tests template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kInvalidIncrements = { 0, 1 }; @@ -287,15 +291,15 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; } if (option == kArgAlpha) { alphas = tester.kAlphaValues; } if (option == kArgBeta) { betas = tester.kBetaValues; } - if (option == kArgChannels) { channelss = tester.kKernelSizes; } + if (option == kArgChannels) { channelss = tester.kChannelValues; } if (option == kArgHeight) { heights = tester.kMatrixDims; } if (option == kArgWidth) { widths = tester.kMatrixDims; } if (option == kArgKernelH) { kernel_hs = tester.kKernelSizes; } if (option == kArgKernelW) { kernel_ws = tester.kKernelSizes; } if (option == kArgPadH) { pad_hs = tester.kPadSizes; } if (option == kArgPadW) { pad_ws = tester.kPadSizes; } - if (option == kArgStrideH) { stride_hs = tester.kKernelSizes; } - if (option == kArgStrideW) { stride_ws = tester.kKernelSizes; } + if (option == kArgStrideH) { stride_hs = tester.kStrideValues; } + if (option == kArgStrideW) { stride_ws = tester.kStrideValues; } if (option == kArgDilationH) { dilation_hs = tester.kDilationSizes; } if (option == kArgDilationW) { dilation_ws = tester.kDilationSizes; } if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; } |