summary refs log tree commit diff
path: root/test/routines/levelx/xconvgemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/routines/levelx/xconvgemm.hpp')
-rw-r--r-- test/routines/levelx/xconvgemm.hpp 20
1 files changed, 14 insertions, 6 deletions
diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp
index 7fa4e701..e67b8174 100644
--- a/test/routines/levelx/xconvgemm.hpp
+++ b/test/routines/levelx/xconvgemm.hpp
@@ -91,7 +91,8 @@ public:
#ifdef OPENCL_API
auto queue_plain = queue();
auto event = cl_event{};
- auto status = Convgemm<T>(args.channels, args.height, args.width,
+ auto status = Convgemm<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -103,7 +104,8 @@ public:
&queue_plain, &event);
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
#elif CUDA_API
- auto status = Convgemm<T>(args.channels, args.height, args.width,
+ auto status = Convgemm<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -189,10 +191,16 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
const auto input_value = buffers_host.a_mat[input_index + args.a_offset];
// Multiplies with the kernel tensor
- const auto kernel_index = kw_id + args.kernel_w * (
- kh_id + args.kernel_h * (
- ci_id + args.channels * (
- co_id)));
+ const auto kernel_index
+ = (args.kernel_mode == KernelMode::kConvolution)
+ ? (args.kernel_w - kw_id - 1) + args.kernel_w * (
+ (args.kernel_h - kh_id - 1) + args.kernel_h * (
+ ci_id + args.channels * (
+ co_id)))
+ : kw_id + args.kernel_w * (
+ kh_id + args.kernel_h * (
+ ci_id + args.channels * (
+ co_id)));
const auto kernel_value = buffers_host.b_mat[kernel_index + args.b_offset];
result += input_value * kernel_value;