summary | refs | log | tree | commit | diff
path: root/test/routines/levelx/xcol2im.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/routines/levelx/xcol2im.hpp')
-rw-r--r-- test/routines/levelx/xcol2im.hpp 14
1 files changed, 10 insertions, 4 deletions
diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp
index 176fceae..c28727e7 100644
--- a/test/routines/levelx/xcol2im.hpp
+++ b/test/routines/levelx/xcol2im.hpp
@@ -31,7 +31,8 @@ public:
// The list of arguments relevant for this routine
static std::vector<std::string> GetOptions() {
- return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW,
+ return {kArgKernelMode,
+ kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW,
kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW,
kArgAOffset, kArgBOffset};
}
@@ -87,7 +88,8 @@ public:
#ifdef OPENCL_API
auto queue_plain = queue();
auto event = cl_event{};
- auto status = Col2im<T>(args.channels, args.height, args.width,
+ auto status = Col2im<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -97,7 +99,8 @@ public:
&queue_plain, &event);
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
#elif CUDA_API
- auto status = Col2im<T>(args.channels, args.height, args.width,
+ auto status = Col2im<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -167,7 +170,10 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
for (auto w_id = size_t{0}; w_id < col_w; ++w_id) { // image width
// Reads the input value
- const auto kernel_index = kw_id + args.kernel_w * kh_id;
+ const auto kernel_index
+ = (args.kernel_mode == KernelMode::kConvolution)
+ ? args.kernel_h * args.kernel_w - kw_id - args.kernel_w * kh_id - 1
+ : kw_id + args.kernel_w * kh_id;
const auto patch_index = w_id + col_w * h_id;
const auto col_index = patch_index + kernel_index * col_w * col_h +
c_id * col_w * col_h * args.kernel_h * args.kernel_w;