diff options
author | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-11-12 10:12:07 +0900 |
---|---|---|
committer | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-11-12 10:12:07 +0900 |
commit | 032e3b0cc00a15dd2af8b4fb82d261eb7b086e26 (patch) | |
tree | cdcf4d0fc342c9ff92ee7ab3f75b0cdeced46e96 /src/routines/levelx/xim2col.cpp | |
parent | 90112618daa0d6b24ae3e53203a636d2e908dfba (diff) |
Add kernel_mode option to im2col, col2im, and convgemm functions
Diffstat (limited to 'src/routines/levelx/xim2col.cpp')
-rw-r--r-- | src/routines/levelx/xim2col.cpp | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp index 09dcc42c..0f786974 100644 --- a/src/routines/levelx/xim2col.cpp +++ b/src/routines/levelx/xim2col.cpp @@ -22,22 +22,26 @@ namespace clblast { // Constructor: forwards to base class constructor template <typename T> Xim2col<T>::Xim2col(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, { -#include "../../kernels/levelx/im2col.opencl" - }) { + Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, { + #include "../../kernels/levelx/im2col.opencl" + }) { } // ================================================================================================= // The main routine template <typename T> -void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size_t width, +void Xim2col<T>::DoIm2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const Buffer<T> &im_buffer, const size_t im_offset, const Buffer<T> &col_buffer, const size_t col_offset) { + // Flip the output along kernel_h and kernel_w, or not. + const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xim2colKernelFlip" : "Xim2colKernelNormal"; + // Makes sure all dimensions are larger than zero if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); } @@ -50,7 +54,7 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1; // Retrieves the kernel from the compiled binary - auto kernel = Kernel(program_, "im2col"); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(height)); |