diff options
-rw-r--r-- | src/utilities/utilities.hpp | 8 | ||||
-rw-r--r-- | test/performance/client.cpp | 26 | ||||
-rw-r--r-- | test/routines/levelx/xim2col.hpp | 4 |
3 files changed, 32 insertions, 6 deletions
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp index 784e0324..ad6edf3a 100644 --- a/src/utilities/utilities.hpp +++ b/src/utilities/utilities.hpp @@ -171,10 +171,10 @@ struct Arguments { size_t channels = 1; size_t height = 1; size_t width = 1; - size_t kernel_h = 1; - size_t kernel_w = 1; - size_t pad_h = 1; - size_t pad_w = 1; + size_t kernel_h = 3; + size_t kernel_w = 3; + size_t pad_h = 0; + size_t pad_w = 0; size_t stride_h = 1; size_t stride_w = 1; size_t dilation_h = 1; diff --git a/test/performance/client.cpp b/test/performance/client.cpp index dc98ffbd..076481f7 100644 --- a/test/performance/client.cpp +++ b/test/performance/client.cpp @@ -93,7 +93,7 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le if (o == kArgAPOffset) { args.ap_offset= GetArgument(command_line_args, help, kArgAPOffset, size_t{0}); } // Scalar result arguments - if (o == kArgDotOffset) { args.dot_offset = GetArgument(command_line_args, help, kArgDotOffset, size_t{0}); } + if (o == kArgDotOffset) { args.dot_offset = GetArgument(command_line_args, help, kArgDotOffset, size_t{0}); } if (o == kArgNrm2Offset) { args.nrm2_offset = GetArgument(command_line_args, help, kArgNrm2Offset, size_t{0}); } if (o == kArgAsumOffset) { args.asum_offset = GetArgument(command_line_args, help, kArgAsumOffset, size_t{0}); } if (o == kArgImaxOffset) { args.imax_offset = GetArgument(command_line_args, help, kArgImaxOffset, size_t{0}); } @@ -104,6 +104,19 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le // Scalar values if (o == kArgAlpha) { args.alpha = GetArgument(command_line_args, help, kArgAlpha, GetScalar<U>()); } if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<U>()); } + + // Arguments for im2col + if (o == kArgChannels) { args.channels = GetArgument(command_line_args, help, kArgChannels, size_t{64}); } + if (o == kArgHeight) { args.height = GetArgument(command_line_args, help, kArgHeight, size_t{64}); } + if (o == kArgWidth) { args.width = GetArgument(command_line_args, help, kArgWidth, size_t{64}); } + if (o == kArgKernelH) { args.kernel_h = GetArgument(command_line_args, help, kArgKernelH, size_t{3}); } + if (o == kArgKernelW) { args.kernel_w = GetArgument(command_line_args, help, kArgKernelW, size_t{3}); } + if (o == kArgPadH) { args.pad_h = GetArgument(command_line_args, help, kArgPadH, size_t{0}); } + if (o == kArgPadW) { args.pad_w = GetArgument(command_line_args, help, kArgPadW, size_t{0}); } + if (o == kArgStrideH) { args.stride_h = GetArgument(command_line_args, help, kArgStrideH, size_t{1}); } + if (o == kArgStrideW) { args.stride_w = GetArgument(command_line_args, help, kArgStrideW, size_t{1}); } + if (o == kArgDilationH) { args.dilation_h = GetArgument(command_line_args, help, kArgDilationH, size_t{1}); } + if (o == kArgDilationW) { args.dilation_w = GetArgument(command_line_args, help, kArgDilationW, size_t{1}); } } // These are the options common to all routines @@ -379,6 +392,17 @@ void Client<T,U>::PrintTableRow(const Arguments<U>& args, else if (o == kArgAsumOffset){integers.push_back(args.asum_offset); } else if (o == kArgImaxOffset){integers.push_back(args.imax_offset); } else if (o == kArgBatchCount){integers.push_back(args.batch_count); } + else if (o == kArgChannels) {integers.push_back(args.channels); } + else if (o == kArgHeight) {integers.push_back(args.height); } + else if (o == kArgWidth) {integers.push_back(args.width); } + else if (o == kArgKernelH) {integers.push_back(args.kernel_h); } + else if (o == kArgKernelW) {integers.push_back(args.kernel_w); } + else if (o == kArgPadH) {integers.push_back(args.pad_h); } + else if (o == kArgPadW) {integers.push_back(args.pad_w); } + else if (o == kArgStrideH) {integers.push_back(args.stride_h); } + else if (o == kArgStrideW) {integers.push_back(args.stride_w); } + else if (o == kArgDilationH) {integers.push_back(args.dilation_h); } + else if (o == kArgDilationW) {integers.push_back(args.dilation_w); } } auto strings = std::vector<std::string>{}; for (auto &o: options_) { diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp index e6aefd9e..59be8156 100644 --- a/test/routines/levelx/xim2col.hpp +++ b/test/routines/levelx/xim2col.hpp @@ -134,7 +134,9 @@ public: return 1; } static size_t GetBytes(const Arguments<T> &args) { - return (1) * sizeof(T); + const auto input = args.channels * args.width * args.height; // possibly less with striding + const auto output = args.kernel_h * args.kernel_w * NumPatches(args); + return (input + output) * sizeof(T); } }; |