From d9db543d75605fb02873e48197572450447481e1 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 17 Dec 2018 21:57:35 +0900 Subject: Fix half-float+kernel_mode test cases of im2col, col2im, and convgemm --- test/correctness/tester.cpp | 1 + test/performance/client.cpp | 2 ++ test/routines/levelx/xcol2im.hpp | 1 + test/routines/levelx/xconvgemm.hpp | 4 +++- test/routines/levelx/xim2col.hpp | 1 + 5 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp index daa43f26..df46167f 100644 --- a/test/correctness/tester.cpp +++ b/test/correctness/tester.cpp @@ -367,6 +367,7 @@ std::string Tester::GetOptionsString(const Arguments &args) { if (o == kArgAlpha) { result += kArgAlpha + equals + ToString(args.alpha) + " "; } if (o == kArgBeta) { result += kArgBeta + equals + ToString(args.beta) + " "; } if (o == kArgBatchCount){result += kArgBatchCount + equals + ToString(args.batch_count) + " "; } + if (o == kArgKernelMode){result += kArgKernelMode + equals + ToString(args.kernel_mode) + " "; } if (o == kArgChannels) { result += kArgChannels + equals + ToString(args.channels) + " "; } if (o == kArgHeight) { result += kArgHeight + equals + ToString(args.height) + " "; } if (o == kArgWidth) { result += kArgWidth + equals + ToString(args.width) + " "; } diff --git a/test/performance/client.cpp b/test/performance/client.cpp index 377e0140..34891429 100644 --- a/test/performance/client.cpp +++ b/test/performance/client.cpp @@ -107,6 +107,7 @@ Arguments Client::ParseArguments(int argc, char *argv[], const size_t le if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar()); } // Arguments for im2col and convgemm + if (o == kArgKernelMode){ args.kernel_mode = GetArgument(command_line_args, help, kArgKernelMode, KernelMode::kConvolution); } if (o == kArgChannels) { args.channels = GetArgument(command_line_args, help, kArgChannels, size_t{64}); } if (o == kArgHeight) { args.height = GetArgument(command_line_args, help, kArgHeight, size_t{64}); } if (o == kArgWidth) { args.width = GetArgument(command_line_args, help, kArgWidth, size_t{64}); } @@ -436,6 +437,7 @@ void Client::PrintTableRow(const Arguments& args, else if (o == kArgAsumOffset){integers.push_back(args.asum_offset); } else if (o == kArgImaxOffset){integers.push_back(args.imax_offset); } else if (o == kArgBatchCount){integers.push_back(args.batch_count); } + else if (o == kArgKernelMode){integers.push_back(static_cast(args.kernel_mode)); } else if (o == kArgChannels) {integers.push_back(args.channels); } else if (o == kArgHeight) {integers.push_back(args.height); } else if (o == kArgWidth) {integers.push_back(args.width); } diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp index c28727e7..c740e4c7 100644 --- a/test/routines/levelx/xcol2im.hpp +++ b/test/routines/levelx/xcol2im.hpp @@ -204,6 +204,7 @@ StatusCode RunReference(const Arguments &args, BuffersHost &bu auto buffers2 = BuffersHost{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy}; auto args2 = Arguments(); args2.a_size = args.a_size; args2.b_size = args.b_size; + args2.kernel_mode = args.kernel_mode; args2.channels = args.channels; args2.height = args.height; args2.width = args.width; args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w; args2.pad_h = args.pad_h; args2.pad_w = args.pad_w; diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp index e67b8174..786bb733 100644 --- a/test/routines/levelx/xconvgemm.hpp +++ b/test/routines/levelx/xconvgemm.hpp @@ -31,7 +31,8 @@ public: // The list of arguments relevant for this routine static std::vector GetOptions() { - return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, + return {kArgKernelMode, + kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW, kArgNumKernels, kArgBatchCount, kArgAOffset, kArgBOffset, kArgCOffset}; } @@ -232,6 +233,7 @@ StatusCode RunReference(const Arguments &args, BuffersHost &bu auto buffers2 = BuffersHost{dummy, dummy, a_buffer2, b_buffer2, c_buffer2, dummy, dummy}; auto args2 = Arguments(); args2.a_size = args.a_size; args2.b_size = args.b_size; args2.c_size = args.c_size; + args2.kernel_mode = args.kernel_mode; args2.channels = args.channels; args2.height = args.height; args2.width = args.width; args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w; args2.pad_h = args.pad_h; args2.pad_w = args.pad_w; diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp index 2a3577c3..2a5ebf8e 100644 --- a/test/routines/levelx/xim2col.hpp +++ b/test/routines/levelx/xim2col.hpp @@ -203,6 +203,7 @@ StatusCode RunReference(const Arguments &args, BuffersHost &bu auto buffers2 = BuffersHost{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy}; auto args2 = Arguments(); args2.a_size = args.a_size; args2.b_size = args.b_size; + args2.kernel_mode = args.kernel_mode; args2.channels = args.channels; args2.height = args.height; args2.width = args.width; args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w; args2.pad_h = args.pad_h; args2.pad_w = args.pad_w; -- cgit v1.2.3