summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-12-22 11:40:19 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-12-22 11:40:19 +0100
commit1f41c3c50abea269390cd65a033186b09da9e454 (patch)
treefd865e5f9cf516143df05483f25cc0ab302f9306
parent0c9411c84465d14d2de33046536403648909eb9f (diff)
parent9819957768174dbb4929b970718a0d6018520979 (diff)
Merge branch 'master' into convolution-fixes-and-tuner
-rw-r--r--test/correctness/tester.cpp1
-rw-r--r--test/performance/client.cpp2
-rw-r--r--test/routines/levelx/xcol2im.hpp1
-rw-r--r--test/routines/levelx/xconvgemm.hpp4
-rw-r--r--test/routines/levelx/xim2col.hpp1
5 files changed, 8 insertions, 1 deletions
diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp
index daa43f26..df46167f 100644
--- a/test/correctness/tester.cpp
+++ b/test/correctness/tester.cpp
@@ -367,6 +367,7 @@ std::string Tester<T,U>::GetOptionsString(const Arguments<U> &args) {
if (o == kArgAlpha) { result += kArgAlpha + equals + ToString(args.alpha) + " "; }
if (o == kArgBeta) { result += kArgBeta + equals + ToString(args.beta) + " "; }
if (o == kArgBatchCount){result += kArgBatchCount + equals + ToString(args.batch_count) + " "; }
+ if (o == kArgKernelMode){result += kArgKernelMode + equals + ToString(args.kernel_mode) + " "; }
if (o == kArgChannels) { result += kArgChannels + equals + ToString(args.channels) + " "; }
if (o == kArgHeight) { result += kArgHeight + equals + ToString(args.height) + " "; }
if (o == kArgWidth) { result += kArgWidth + equals + ToString(args.width) + " "; }
diff --git a/test/performance/client.cpp b/test/performance/client.cpp
index 377e0140..34891429 100644
--- a/test/performance/client.cpp
+++ b/test/performance/client.cpp
@@ -107,6 +107,7 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le
if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<U>()); }
// Arguments for im2col and convgemm
+ if (o == kArgKernelMode){ args.kernel_mode = GetArgument(command_line_args, help, kArgKernelMode, KernelMode::kConvolution); }
if (o == kArgChannels) { args.channels = GetArgument(command_line_args, help, kArgChannels, size_t{64}); }
if (o == kArgHeight) { args.height = GetArgument(command_line_args, help, kArgHeight, size_t{64}); }
if (o == kArgWidth) { args.width = GetArgument(command_line_args, help, kArgWidth, size_t{64}); }
@@ -436,6 +437,7 @@ void Client<T,U>::PrintTableRow(const Arguments<U>& args,
else if (o == kArgAsumOffset){integers.push_back(args.asum_offset); }
else if (o == kArgImaxOffset){integers.push_back(args.imax_offset); }
else if (o == kArgBatchCount){integers.push_back(args.batch_count); }
+ else if (o == kArgKernelMode){integers.push_back(static_cast<size_t>(args.kernel_mode)); }
else if (o == kArgChannels) {integers.push_back(args.channels); }
else if (o == kArgHeight) {integers.push_back(args.height); }
else if (o == kArgWidth) {integers.push_back(args.width); }
diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp
index c28727e7..c740e4c7 100644
--- a/test/routines/levelx/xcol2im.hpp
+++ b/test/routines/levelx/xcol2im.hpp
@@ -204,6 +204,7 @@ StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &bu
auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy};
auto args2 = Arguments<float>();
args2.a_size = args.a_size; args2.b_size = args.b_size;
+ args2.kernel_mode = args.kernel_mode;
args2.channels = args.channels; args2.height = args.height; args2.width = args.width;
args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w;
args2.pad_h = args.pad_h; args2.pad_w = args.pad_w;
diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp
index e67b8174..786bb733 100644
--- a/test/routines/levelx/xconvgemm.hpp
+++ b/test/routines/levelx/xconvgemm.hpp
@@ -31,7 +31,8 @@ public:
// The list of arguments relevant for this routine
static std::vector<std::string> GetOptions() {
- return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW,
+ return {kArgKernelMode,
+ kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW,
kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW, kArgNumKernels, kArgBatchCount,
kArgAOffset, kArgBOffset, kArgCOffset};
}
@@ -232,6 +233,7 @@ StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &bu
auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, c_buffer2, dummy, dummy};
auto args2 = Arguments<float>();
args2.a_size = args.a_size; args2.b_size = args.b_size; args2.c_size = args.c_size;
+ args2.kernel_mode = args.kernel_mode;
args2.channels = args.channels; args2.height = args.height; args2.width = args.width;
args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w;
args2.pad_h = args.pad_h; args2.pad_w = args.pad_w;
diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp
index 2a3577c3..2a5ebf8e 100644
--- a/test/routines/levelx/xim2col.hpp
+++ b/test/routines/levelx/xim2col.hpp
@@ -203,6 +203,7 @@ StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &bu
auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy};
auto args2 = Arguments<float>();
args2.a_size = args.a_size; args2.b_size = args.b_size;
+ args2.kernel_mode = args.kernel_mode;
args2.channels = args.channels; args2.height = args.height; args2.width = args.width;
args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w;
args2.pad_h = args.pad_h; args2.pad_w = args.pad_w;