summaryrefslogtreecommitdiff
path: root/test/correctness/testblas.hpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-08-19 16:55:09 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-08-19 16:55:09 +0200
commit132e62892de91c1dec2ffe1123a106bba0ffd822 (patch)
treeaf33c059b564ce1f90f197c0d6f834b1d8b2f404 /test/correctness/testblas.hpp
parent777681dcbdf18493320dd7b94fccd5c6faee9455 (diff)
Implemented proper im2col reference function and completd tests
Diffstat (limited to 'test/correctness/testblas.hpp')
-rw-r--r--test/correctness/testblas.hpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp
index 1c0cf9e3..4e02fd28 100644
--- a/test/correctness/testblas.hpp
+++ b/test/correctness/testblas.hpp
@@ -57,6 +57,7 @@ class TestBlas: public Tester<T,U> {
static const std::vector<size_t> kMatrixVectorDims;
static const std::vector<size_t> kBandSizes;
static const std::vector<size_t> kPadSizes;
+ static const std::vector<size_t> kDilationSizes;
static const std::vector<size_t> kKernelSizes;
static const std::vector<size_t> kBatchCounts;
const std::vector<size_t> kOffsets;
@@ -132,7 +133,8 @@ template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kMatr
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kMatrixVectorDims = { 61, 256 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBandSizes = { 4, 19 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBatchCounts = { 1, 3 };
-template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadSizes = { 0 };
+template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadSizes = { 0, 1 };
+template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kDilationSizes = { 1, 2 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kKernelSizes = { 1, 3 };
// Test settings for the invalid tests
@@ -282,7 +284,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; }
if (option == kArgAlpha) { alphas = tester.kAlphaValues; }
if (option == kArgBeta) { betas = tester.kBetaValues; }
- if (option == kArgChannels) { channelss = tester.kMatrixDims; }
+ if (option == kArgChannels) { channelss = tester.kKernelSizes; }
if (option == kArgHeight) { heights = tester.kMatrixDims; }
if (option == kArgWidth) { widths = tester.kMatrixDims; }
if (option == kArgKernelH) { kernel_hs = tester.kKernelSizes; }
@@ -291,8 +293,8 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
if (option == kArgPadW) { pad_ws = tester.kPadSizes; }
if (option == kArgStrideH) { stride_hs = tester.kKernelSizes; }
if (option == kArgStrideW) { stride_ws = tester.kKernelSizes; }
- if (option == kArgDilationH) { dilation_hs = tester.kKernelSizes; }
- if (option == kArgDilationW) { dilation_ws = tester.kKernelSizes; }
+ if (option == kArgDilationH) { dilation_hs = tester.kDilationSizes; }
+ if (option == kArgDilationW) { dilation_ws = tester.kDilationSizes; }
if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; }
if (option == kArgXOffset) { x_sizes = tester.kVecSizes; }