diff options
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- | src/clblast.cpp | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index ca401066..9089b17c 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -72,6 +72,7 @@ // Level-x includes (non-BLAS) #include "routines/levelx/xomatcopy.hpp" +#include "routines/levelx/xim2col.hpp" #include "routines/levelx/xaxpybatched.hpp" #include "routines/levelx/xgemmbatched.hpp" @@ -2212,6 +2213,42 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL +template <typename T> +StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem im_buffer, const size_t im_offset, + cl_mem col_buffer, const size_t col_offset, + cl_command_queue* queue, cl_event* event) { + try { + auto queue_cpp = Queue(*queue); + auto routine = Xim2col<T>(queue_cpp, event); + routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + Buffer<T>(im_buffer), im_offset, + Buffer<T>(col_buffer), col_offset); + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); + // Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED template <typename T> StatusCode AxpyBatched(const size_t n, |