diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-08-19 16:55:09 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-08-19 16:55:09 +0200 |
commit | 132e62892de91c1dec2ffe1123a106bba0ffd822 (patch) | |
tree | af33c059b564ce1f90f197c0d6f834b1d8b2f404 /test/correctness | |
parent | 777681dcbdf18493320dd7b94fccd5c6faee9455 (diff) |
Implemented proper im2col reference function and completd tests
Diffstat (limited to 'test/correctness')
-rw-r--r-- | test/correctness/testblas.hpp | 10 | ||||
-rw-r--r-- | test/correctness/tester.cpp | 6 |
2 files changed, 12 insertions, 4 deletions
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index 1c0cf9e3..4e02fd28 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -57,6 +57,7 @@ class TestBlas: public Tester<T,U> { static const std::vector<size_t> kMatrixVectorDims; static const std::vector<size_t> kBandSizes; static const std::vector<size_t> kPadSizes; + static const std::vector<size_t> kDilationSizes; static const std::vector<size_t> kKernelSizes; static const std::vector<size_t> kBatchCounts; const std::vector<size_t> kOffsets; @@ -132,7 +133,8 @@ template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kMatr template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kMatrixVectorDims = { 61, 256 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBandSizes = { 4, 19 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBatchCounts = { 1, 3 }; -template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadSizes = { 0 }; +template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadSizes = { 0, 1 }; +template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kDilationSizes = { 1, 2 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kKernelSizes = { 1, 3 }; // Test settings for the invalid tests @@ -282,7 +284,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; } if (option == kArgAlpha) { alphas = tester.kAlphaValues; } if (option == kArgBeta) { betas = tester.kBetaValues; } - if (option == kArgChannels) { channelss = tester.kMatrixDims; } + if (option == kArgChannels) { channelss = tester.kKernelSizes; } if (option == kArgHeight) { heights = tester.kMatrixDims; } if (option == kArgWidth) { widths = tester.kMatrixDims; } if (option == kArgKernelH) { kernel_hs = tester.kKernelSizes; } @@ -291,8 +293,8 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgPadW) { pad_ws = tester.kPadSizes; } if (option == kArgStrideH) { stride_hs = tester.kKernelSizes; } if (option == kArgStrideW) { stride_ws = tester.kKernelSizes; } - if (option == kArgDilationH) { dilation_hs = tester.kKernelSizes; } - if (option == kArgDilationW) { dilation_ws = tester.kKernelSizes; } + if (option == kArgDilationH) { dilation_hs = tester.kDilationSizes; } + if (option == kArgDilationW) { dilation_ws = tester.kDilationSizes; } if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; } if (option == kArgXOffset) { x_sizes = tester.kVecSizes; } diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp index 648aef6e..9dbd8934 100644 --- a/test/correctness/tester.cpp +++ b/test/correctness/tester.cpp @@ -371,6 +371,12 @@ std::string Tester<T,U>::GetOptionsString(const Arguments<U> &args) { if (o == kArgWidth) { result += kArgWidth + equals + ToString(args.width) + " "; } if (o == kArgKernelH) { result += kArgKernelH + equals + ToString(args.kernel_h) + " "; } if (o == kArgKernelW) { result += kArgKernelW + equals + ToString(args.kernel_w) + " "; } + if (o == kArgPadH) { result += kArgPadH + equals + ToString(args.pad_h) + " "; } + if (o == kArgPadW) { result += kArgPadW + equals + ToString(args.pad_w) + " "; } + if (o == kArgStrideH) { result += kArgStrideH + equals + ToString(args.stride_h) + " "; } + if (o == kArgStrideW) { result += kArgStrideW + equals + ToString(args.stride_w) + " "; } + if (o == kArgDilationH){ result += kArgDilationH + equals + ToString(args.dilation_h) + " "; } + if (o == kArgDilationW){ result += kArgDilationW + equals + ToString(args.dilation_w) + " "; } } return result; } |