From b60828036122c5fe6e0305963ddc1ada6a2effff Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Wed, 9 May 2018 19:59:31 +0200 Subject: Fixed the performance client for convgemm and added GFLOPS measurements --- test/performance/client.cpp | 4 +++- test/routines/levelx/xconvgemm.hpp | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/test/performance/client.cpp b/test/performance/client.cpp index 9480d11a..48690c3d 100644 --- a/test/performance/client.cpp +++ b/test/performance/client.cpp @@ -105,7 +105,7 @@ Arguments Client::ParseArguments(int argc, char *argv[], const size_t le if (o == kArgAlpha) { args.alpha = GetArgument(command_line_args, help, kArgAlpha, GetScalar()); } if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar()); } - // Arguments for im2col + // Arguments for im2col and convgemm if (o == kArgChannels) { args.channels = GetArgument(command_line_args, help, kArgChannels, size_t{64}); } if (o == kArgHeight) { args.height = GetArgument(command_line_args, help, kArgHeight, size_t{64}); } if (o == kArgWidth) { args.width = GetArgument(command_line_args, help, kArgWidth, size_t{64}); } @@ -117,6 +117,7 @@ Arguments Client::ParseArguments(int argc, char *argv[], const size_t le if (o == kArgStrideW) { args.stride_w = GetArgument(command_line_args, help, kArgStrideW, size_t{1}); } if (o == kArgDilationH) { args.dilation_h = GetArgument(command_line_args, help, kArgDilationH, size_t{1}); } if (o == kArgDilationW) { args.dilation_w = GetArgument(command_line_args, help, kArgDilationW, size_t{1}); } + if (o == kArgNumKernels){ args.num_kernels = GetArgument(command_line_args, help, kArgNumKernels, size_t{1}); } } // These are the options common to all routines @@ -416,6 +417,7 @@ void Client::PrintTableRow(const Arguments& args, else if (o == kArgStrideW) {integers.push_back(args.stride_w); } else if (o == kArgDilationH) {integers.push_back(args.dilation_h); } else if (o == kArgDilationW) {integers.push_back(args.dilation_w); } + else if (o == kArgNumKernels){integers.push_back(args.num_kernels); } } auto strings = std::vector{}; for (auto &o: options_) { diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp index 6ca5965b..7233f7b6 100644 --- a/test/routines/levelx/xconvgemm.hpp +++ b/test/routines/levelx/xconvgemm.hpp @@ -151,7 +151,9 @@ public: // Describes how to compute performance metrics static size_t GetFlops(const Arguments &args) { - return args.batch_count; // TODO + const auto patch_size = args.kernel_h * args.kernel_w * args.channels; + const auto num_patches = OutputHeight(args) * OutputWidth(args); + return args.batch_count * 2 * num_patches * args.num_kernels * patch_size; } static size_t GetBytes(const Arguments &args) { return (GetSizeA(args) + GetSizeB(args) + GetSizeC(args)) * sizeof(T); -- cgit v1.2.3