summary | refs | log | tree | commit | diff
path: root/test
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2018-05-09 19:59:31 +0200
committer: Cedric Nugteren <web@cedricnugteren.nl> 2018-05-09 19:59:31 +0200
commit: b60828036122c5fe6e0305963ddc1ada6a2effff (patch)
tree: 58d6b6e6572cc7fe9442d9949a0ac61ba3d8a0bc /test
parent: a4119531eedd5220c9f02c8e1a8a5c3376367049 (diff)
Fixed the performance client for convgemm and added GFLOPS measurements
Diffstat (limited to 'test')
-rw-r--r-- test/performance/client.cpp | 4
-rw-r--r-- test/routines/levelx/xconvgemm.hpp | 4
2 files changed, 6 insertions, 2 deletions
diff --git a/test/performance/client.cpp b/test/performance/client.cpp
index 9480d11a..48690c3d 100644
--- a/test/performance/client.cpp
+++ b/test/performance/client.cpp
@@ -105,7 +105,7 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le
if (o == kArgAlpha) { args.alpha = GetArgument(command_line_args, help, kArgAlpha, GetScalar<U>()); }
if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<U>()); }
- // Arguments for im2col
+ // Arguments for im2col and convgemm
if (o == kArgChannels) { args.channels = GetArgument(command_line_args, help, kArgChannels, size_t{64}); }
if (o == kArgHeight) { args.height = GetArgument(command_line_args, help, kArgHeight, size_t{64}); }
if (o == kArgWidth) { args.width = GetArgument(command_line_args, help, kArgWidth, size_t{64}); }
@@ -117,6 +117,7 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le
if (o == kArgStrideW) { args.stride_w = GetArgument(command_line_args, help, kArgStrideW, size_t{1}); }
if (o == kArgDilationH) { args.dilation_h = GetArgument(command_line_args, help, kArgDilationH, size_t{1}); }
if (o == kArgDilationW) { args.dilation_w = GetArgument(command_line_args, help, kArgDilationW, size_t{1}); }
+ if (o == kArgNumKernels){ args.num_kernels = GetArgument(command_line_args, help, kArgNumKernels, size_t{1}); }
}
// These are the options common to all routines
@@ -416,6 +417,7 @@ void Client<T,U>::PrintTableRow(const Arguments<U>& args,
else if (o == kArgStrideW) {integers.push_back(args.stride_w); }
else if (o == kArgDilationH) {integers.push_back(args.dilation_h); }
else if (o == kArgDilationW) {integers.push_back(args.dilation_w); }
+ else if (o == kArgNumKernels){integers.push_back(args.num_kernels); }
}
auto strings = std::vector<std::string>{};
for (auto &o: options_) {
diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp
index 6ca5965b..7233f7b6 100644
--- a/test/routines/levelx/xconvgemm.hpp
+++ b/test/routines/levelx/xconvgemm.hpp
@@ -151,7 +151,9 @@ public:
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
- return args.batch_count; // TODO
+ const auto patch_size = args.kernel_h * args.kernel_w * args.channels;
+ const auto num_patches = OutputHeight(args) * OutputWidth(args);
+ return args.batch_count * 2 * num_patches * args.num_kernels * patch_size;
}
static size_t GetBytes(const Arguments<T> &args) {
return (GetSizeA(args) + GetSizeB(args) + GetSizeC(args)) * sizeof(T);