diff options
author | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-11-12 10:12:07 +0900 |
---|---|---|
committer | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-11-12 10:12:07 +0900 |
commit | 032e3b0cc00a15dd2af8b4fb82d261eb7b086e26 (patch) | |
tree | cdcf4d0fc342c9ff92ee7ab3f75b0cdeced46e96 /test/routines/levelx | |
parent | 90112618daa0d6b24ae3e53203a636d2e908dfba (diff) |
Add kernel_mode option to im2col, col2im, and convgemm functions
Diffstat (limited to 'test/routines/levelx')
-rw-r--r-- | test/routines/levelx/xcol2im.hpp | 14 | ||||
-rw-r--r-- | test/routines/levelx/xconvgemm.hpp | 20 | ||||
-rw-r--r-- | test/routines/levelx/xim2col.hpp | 14 |
3 files changed, 34 insertions, 14 deletions
diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp index 176fceae..c28727e7 100644 --- a/test/routines/levelx/xcol2im.hpp +++ b/test/routines/levelx/xcol2im.hpp @@ -31,7 +31,8 @@ public: // The list of arguments relevant for this routine static std::vector<std::string> GetOptions() { - return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, + return {kArgKernelMode, + kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW, kArgAOffset, kArgBOffset}; } @@ -87,7 +88,8 @@ public: #ifdef OPENCL_API auto queue_plain = queue(); auto event = cl_event{}; - auto status = Col2im<T>(args.channels, args.height, args.width, + auto status = Col2im<T>(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -97,7 +99,8 @@ public: &queue_plain, &event); if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } #elif CUDA_API - auto status = Col2im<T>(args.channels, args.height, args.width, + auto status = Col2im<T>(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -167,7 +170,10 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host) for (auto w_id = size_t{0}; w_id < col_w; ++w_id) { // image width // Reads the input value - const auto kernel_index = kw_id + args.kernel_w * kh_id; + const auto kernel_index + = (args.kernel_mode == KernelMode::kConvolution) + ? args.kernel_h * args.kernel_w - kw_id - args.kernel_w * kh_id - 1 + : kw_id + args.kernel_w * kh_id; const auto patch_index = w_id + col_w * h_id; const auto col_index = patch_index + kernel_index * col_w * col_h + c_id * col_w * col_h * args.kernel_h * args.kernel_w; diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp index 7fa4e701..e67b8174 100644 --- a/test/routines/levelx/xconvgemm.hpp +++ b/test/routines/levelx/xconvgemm.hpp @@ -91,7 +91,8 @@ public: #ifdef OPENCL_API auto queue_plain = queue(); auto event = cl_event{}; - auto status = Convgemm<T>(args.channels, args.height, args.width, + auto status = Convgemm<T>(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -103,7 +104,8 @@ public: &queue_plain, &event); if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } #elif CUDA_API - auto status = Convgemm<T>(args.channels, args.height, args.width, + auto status = Convgemm<T>(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -189,10 +191,16 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host) const auto input_value = buffers_host.a_mat[input_index + args.a_offset]; // Multiplies with the kernel tensor - const auto kernel_index = kw_id + args.kernel_w * ( - kh_id + args.kernel_h * ( - ci_id + args.channels * ( - co_id))); + const auto kernel_index + = (args.kernel_mode == KernelMode::kConvolution) + ? (args.kernel_w - kw_id - 1) + args.kernel_w * ( + (args.kernel_h - kh_id - 1) + args.kernel_h * ( + ci_id + args.channels * ( + co_id))) + : kw_id + args.kernel_w * ( + kh_id + args.kernel_h * ( + ci_id + args.channels * ( + co_id))); const auto kernel_value = buffers_host.b_mat[kernel_index + args.b_offset]; result += input_value * kernel_value; diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp index acf7998b..2a3577c3 100644 --- a/test/routines/levelx/xim2col.hpp +++ b/test/routines/levelx/xim2col.hpp @@ -31,7 +31,8 @@ public: // The list of arguments relevant for this routine static std::vector<std::string> GetOptions() { - return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, + return {kArgKernelMode, + kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW, kArgAOffset, kArgBOffset}; } @@ -87,7 +88,8 @@ public: #ifdef OPENCL_API auto queue_plain = queue(); auto event = cl_event{}; - auto status = Im2col<T>(args.channels, args.height, args.width, + auto status = Im2col<T>(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -97,7 +99,8 @@ public: &queue_plain, &event); if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } #elif CUDA_API - auto status = Im2col<T>(args.channels, args.height, args.width, + auto status = Im2col<T>(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -175,7 +178,10 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host) } // Sets the output value - const auto kernel_index = kw_id + args.kernel_w * kh_id; + const auto kernel_index + = (args.kernel_mode == KernelMode::kConvolution) + ? args.kernel_h * args.kernel_w - kw_id - args.kernel_w * kh_id - 1 + : kw_id + args.kernel_w * kh_id; const auto patch_index = w_id + col_w * h_id; const auto col_index = patch_index + kernel_index * col_w * col_h + c_id * col_w * col_h * args.kernel_h * args.kernel_w; |