diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/clblast.cpp | 57 | ||||
-rw-r--r-- | src/clblast_c.cpp | 78 | ||||
-rw-r--r-- | src/clblast_cuda.cpp | 57 | ||||
-rw-r--r-- | src/clblast_netlib_c.cpp | 48 | ||||
-rw-r--r-- | src/kernels/levelx/col2im.opencl | 73 | ||||
-rw-r--r-- | src/kernels/levelx/im2col.opencl | 59 | ||||
-rw-r--r-- | src/routines/levelx/xcol2im.cpp | 8 | ||||
-rw-r--r-- | src/routines/levelx/xcol2im.hpp | 3 | ||||
-rw-r--r-- | src/routines/levelx/xconvgemm.cpp | 6 | ||||
-rw-r--r-- | src/routines/levelx/xconvgemm.hpp | 3 | ||||
-rw-r--r-- | src/routines/levelx/xim2col.cpp | 14 | ||||
-rw-r--r-- | src/routines/levelx/xim2col.hpp | 3 | ||||
-rw-r--r-- | src/utilities/utilities.cpp | 8 | ||||
-rw-r--r-- | src/utilities/utilities.hpp | 2 |
14 files changed, 304 insertions, 115 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index e45f504a..180693e7 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2218,79 +2218,94 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose, // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template <typename T> -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { auto queue_cpp = Queue(*queue); auto routine = Xim2col<T>(queue_cpp, event); - routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer<T>(im_buffer), im_offset, Buffer<T>(col_buffer), col_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<float2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<double2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template <typename T> -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { auto queue_cpp = Queue(*queue); auto routine = Xcol2im<T>(queue_cpp, event); - routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoCol2im(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer<T>(col_buffer), col_offset, Buffer<T>(im_buffer), im_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<float2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<double2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template <typename T> -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, @@ -2298,24 +2313,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid try { auto queue_cpp = Queue(*queue); auto routine = Xconvgemm<T>(queue_cpp, event); - routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + routine.DoConvgemm(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, Buffer<T>(im_buffer), im_offset, Buffer<T>(kernel_buffer), kernel_offset, Buffer<T>(result_buffer), result_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index 645a69b1..a224230a 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -3613,65 +3613,75 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran } // IM2COL -CLBlastStatusCode CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastSim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<float>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastDim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<double>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastCim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<float2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastZim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<double2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastHim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<half>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) @@ -3680,65 +3690,75 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con } // COL2IM -CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastScol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<float>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastDcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<double>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastCcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<float2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastZcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<double2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastHcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<half>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) @@ -3747,14 +3767,16 @@ CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, con } // CONVGEMM -CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastSconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Convgemm<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm<float>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, @@ -3762,14 +3784,16 @@ CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, c ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastDconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Convgemm<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm<double>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, @@ -3777,14 +3801,16 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastHconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Convgemm<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm<half>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp index 03d995ba..264f360d 100644 --- a/src/clblast_cuda.cpp +++ b/src/clblast_cuda.cpp @@ -2314,7 +2314,8 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose, // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template <typename T> -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr im_buffer, const size_t im_offset, CUdeviceptr col_buffer, const size_t col_offset, const CUcontext context, const CUdevice device) { @@ -2323,36 +2324,43 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xim2col<T>(queue_cpp, nullptr); - routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer<T>(im_buffer), im_offset, Buffer<T>(col_buffer), col_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<float2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<double2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template <typename T> -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr col_buffer, const size_t col_offset, CUdeviceptr im_buffer, const size_t im_offset, const CUcontext context, const CUdevice device) { @@ -2361,36 +2369,43 @@ StatusCode Col2im(const size_t channels, const size_t height, const size_t width const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xcol2im<T>(queue_cpp, nullptr); - routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoCol2im(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer<T>(col_buffer), col_offset, Buffer<T>(im_buffer), im_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<float2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<double2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template <typename T> -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const CUdeviceptr im_buffer, const size_t im_offset, const CUdeviceptr kernel_buffer, const size_t kernel_offset, CUdeviceptr result_buffer, const size_t result_offset, @@ -2400,24 +2415,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xconvgemm<T>(queue_cpp, nullptr); - routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + routine.DoConvgemm(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, Buffer<T>(im_buffer), im_offset, Buffer<T>(kernel_buffer), kernel_offset, Buffer<T>(result_buffer), result_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, diff --git a/src/clblast_netlib_c.cpp b/src/clblast_netlib_c.cpp index 22570535..3a8f729e 100644 --- a/src/clblast_netlib_c.cpp +++ b/src/clblast_netlib_c.cpp @@ -4878,7 +4878,8 @@ void cblas_zomatcopy(const CLBlastLayout layout, const CLBlastTranspose a_transp } // IM2COL -void cblas_sim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_sim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const float* im, float* col) { OPTIONAL_STATIC auto device = get_device(); @@ -4891,7 +4892,8 @@ void cblas_sim2col(const int channels, const int height, const int width, const im_buffer.Write(queue, im_size, reinterpret_cast<const float*>(im)); col_buffer.Write(queue, col_size, reinterpret_cast<float*>(col)); auto queue_cl = queue(); - auto s = clblast::Im2col<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Im2col<float>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer(), 0, col_buffer(), 0, &queue_cl); @@ -4900,7 +4902,8 @@ void cblas_sim2col(const int channels, const int height, const int width, const } col_buffer.Read(queue, col_size, reinterpret_cast<float*>(col)); } -void cblas_dim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_dim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const double* im, double* col) { OPTIONAL_STATIC auto device = get_device(); @@ -4913,7 +4916,8 @@ void cblas_dim2col(const int channels, const int height, const int width, const im_buffer.Write(queue, im_size, reinterpret_cast<const double*>(im)); col_buffer.Write(queue, col_size, reinterpret_cast<double*>(col)); auto queue_cl = queue(); - auto s = clblast::Im2col<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Im2col<double>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer(), 0, col_buffer(), 0, &queue_cl); @@ -4922,7 +4926,8 @@ void cblas_dim2col(const int channels, const int height, const int width, const } col_buffer.Read(queue, col_size, reinterpret_cast<double*>(col)); } -void cblas_cim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_cim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* im, void* col) { OPTIONAL_STATIC auto device = get_device(); @@ -4935,7 +4940,8 @@ void cblas_cim2col(const int channels, const int height, const int width, const im_buffer.Write(queue, im_size, reinterpret_cast<const float2*>(im)); col_buffer.Write(queue, col_size, reinterpret_cast<float2*>(col)); auto queue_cl = queue(); - auto s = clblast::Im2col<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Im2col<float2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer(), 0, col_buffer(), 0, &queue_cl); @@ -4944,7 +4950,8 @@ void cblas_cim2col(const int channels, const int height, const int width, const } col_buffer.Read(queue, col_size, reinterpret_cast<float2*>(col)); } -void cblas_zim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_zim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* im, void* col) { OPTIONAL_STATIC auto device = get_device(); @@ -4957,7 +4964,8 @@ void cblas_zim2col(const int channels, const int height, const int width, const im_buffer.Write(queue, im_size, reinterpret_cast<const double2*>(im)); col_buffer.Write(queue, col_size, reinterpret_cast<double2*>(col)); auto queue_cl = queue(); - auto s = clblast::Im2col<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Im2col<double2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer(), 0, col_buffer(), 0, &queue_cl); @@ -4968,7 +4976,8 @@ void cblas_zim2col(const int channels, const int height, const int width, const } // COL2IM -void cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_scol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const float* col, float* im) { OPTIONAL_STATIC auto device = get_device(); @@ -4981,7 +4990,8 @@ void cblas_scol2im(const int channels, const int height, const int width, const col_buffer.Write(queue, col_size, reinterpret_cast<const float*>(col)); im_buffer.Write(queue, im_size, reinterpret_cast<float*>(im)); auto queue_cl = queue(); - auto s = clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Col2im<float>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer(), 0, im_buffer(), 0, &queue_cl); @@ -4990,7 +5000,8 @@ void cblas_scol2im(const int channels, const int height, const int width, const } im_buffer.Read(queue, im_size, reinterpret_cast<float*>(im)); } -void cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_dcol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const double* col, double* im) { OPTIONAL_STATIC auto device = get_device(); @@ -5003,7 +5014,8 @@ void cblas_dcol2im(const int channels, const int height, const int width, const col_buffer.Write(queue, col_size, reinterpret_cast<const double*>(col)); im_buffer.Write(queue, im_size, reinterpret_cast<double*>(im)); auto queue_cl = queue(); - auto s = clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Col2im<double>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer(), 0, im_buffer(), 0, &queue_cl); @@ -5012,7 +5024,8 @@ void cblas_dcol2im(const int channels, const int height, const int width, const } im_buffer.Read(queue, im_size, reinterpret_cast<double*>(im)); } -void cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_ccol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* col, void* im) { OPTIONAL_STATIC auto device = get_device(); @@ -5025,7 +5038,8 @@ void cblas_ccol2im(const int channels, const int height, const int width, const col_buffer.Write(queue, col_size, reinterpret_cast<const float2*>(col)); im_buffer.Write(queue, im_size, reinterpret_cast<float2*>(im)); auto queue_cl = queue(); - auto s = clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Col2im<float2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer(), 0, im_buffer(), 0, &queue_cl); @@ -5034,7 +5048,8 @@ void cblas_ccol2im(const int channels, const int height, const int width, const } im_buffer.Read(queue, im_size, reinterpret_cast<float2*>(im)); } -void cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_zcol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* col, void* im) { OPTIONAL_STATIC auto device = get_device(); @@ -5047,7 +5062,8 @@ void cblas_zcol2im(const int channels, const int height, const int width, const col_buffer.Write(queue, col_size, reinterpret_cast<const double2*>(col)); im_buffer.Write(queue, im_size, reinterpret_cast<double2*>(im)); auto queue_cl = queue(); - auto s = clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Col2im<double2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer(), 0, im_buffer(), 0, &queue_cl); diff --git a/src/kernels/levelx/col2im.opencl b/src/kernels/levelx/col2im.opencl index a37db24f..484a7a98 100644 --- a/src/kernels/levelx/col2im.opencl +++ b/src/kernels/levelx/col2im.opencl @@ -28,18 +28,20 @@ inline int grid_ceil(const int x, const int step) { return x > 0 ? ((x - 1) / step + 1) * step : x / step * step; } +// Main body of the kernel __kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) -void col2im(const int input_h, const int input_w, const int channels, - const int output_h, const int output_w, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int stride_bez_h, const int stride_bez_w, - const int dilation_bez_h, const int dilation_bez_w, - const int gcd_h, const int gcd_w, - const __global real* restrict col_buffer, const int col_offset, - __global real* im_buffer, const int im_offset) { +INLINE_FUNC void Xcol2im(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const bool kernel_flip, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { const int input_h_scaled = (input_h - 1) / gcd_h + 1; @@ -71,8 +73,9 @@ void col2im(const int input_h, const int input_w, const int channels, const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w; const int h_id = th / stride_h + stride_bez_h * gcd_scale_h; const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w; - - const int kernel_index = kw_id + kernel_w * kh_id; + const int kernel_index = (kernel_flip) + ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1 + : kw_id + kernel_w * kh_id; const int patch_index = w_id + output_w * h_id; const int output_index = patch_index + kernel_index * output_w * output_h + c_id * output_w * output_h * kernel_h * kernel_w; @@ -89,6 +92,50 @@ void col2im(const int input_h, const int input_w, const int channels, // ================================================================================================= +// Kernel flip version of the Xcol2im kernel (for convolution) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xcol2imKernelFlip(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { + const bool kernel_flip = true; + Xcol2im(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + stride_bez_h, stride_bez_w, dilation_bez_h, dilation_bez_w, gcd_h, gcd_w, + kernel_flip, + col_buffer, col_offset, im_buffer, im_offset); +} + +// Normal version of the Xcol2im kernel (for cross-correlation) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xcol2imKernelNormal(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { + const bool kernel_flip = false; + Xcol2im(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + stride_bez_h, stride_bez_w, dilation_bez_h, dilation_bez_w, gcd_h, gcd_w, + kernel_flip, + col_buffer, col_offset, im_buffer, im_offset); +} + +// ================================================================================================= + // End of the C++11 raw string literal )" diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl index 301e076b..5db4cb5f 100644 --- a/src/kernels/levelx/im2col.opencl +++ b/src/kernels/levelx/im2col.opencl @@ -25,15 +25,16 @@ R"( // ================================================================================================= -__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) -void im2col(const int input_h, const int input_w, const int channels, - const int output_h, const int output_w, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const __global real* restrict im_buffer, const int im_offset, - __global real* col_buffer, const int col_offset) { +// Main body of the kernel +INLINE_FUNC void Xim2col(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const bool kernel_flip, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { // Thread IDs const int w_id = get_global_id(0); // image width, max 'output_w' @@ -58,7 +59,9 @@ void im2col(const int input_h, const int input_w, const int channels, } // Sets the output value - const int kernel_index = kw_id + kernel_w * kh_id; + const int kernel_index = (kernel_flip) + ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1 + : kw_id + kernel_w * kh_id; const int patch_index = w_id + output_w * h_id; const int output_index = patch_index + kernel_index * output_w * output_h + c_id * output_w * output_h * kernel_h * kernel_w; @@ -70,6 +73,42 @@ void im2col(const int input_h, const int input_w, const int channels, // ================================================================================================= +// Kernel flip version of the Xim2col kernel (for convolution) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xim2colKernelFlip(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { + const bool kernel_flip = true; + Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + kernel_flip, + im_buffer, im_offset, col_buffer, col_offset); +} + +// Normal version of the Xim2col kernel (for cross-correlation) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xim2colKernelNormal(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { + const bool kernel_flip = false; + Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + kernel_flip, + im_buffer, im_offset, col_buffer, col_offset); +} + +// ================================================================================================= + // End of the C++11 raw string literal )" diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp index 7a0c36b7..d285e5c0 100644 --- a/src/routines/levelx/xcol2im.cpp +++ b/src/routines/levelx/xcol2im.cpp @@ -31,13 +31,17 @@ Xcol2im<T>::Xcol2im(Queue &queue, EventPointer event, const std::string &name): // The main routine template <typename T> -void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size_t width, +void Xcol2im<T>::DoCol2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const Buffer<T> &col_buffer, const size_t col_offset, const Buffer<T> &im_buffer, const size_t im_offset) { + // Flip the output along kernel_h and kernel_w, or not. + const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xcol2imKernelFlip" : "Xcol2imKernelNormal"; + // Makes sure all dimensions are larger than zero if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); } @@ -59,7 +63,7 @@ void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size EuclidGCD(static_cast<int>(stride_w), static_cast<int>(dilation_w), stride_bez_w, dilation_bez_w, gcd_w); // Retrieves the kernel from the compiled binary - auto kernel = Kernel(program_, "col2im"); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(height)); diff --git a/src/routines/levelx/xcol2im.hpp b/src/routines/levelx/xcol2im.hpp index 86d68c45..522c717e 100644 --- a/src/routines/levelx/xcol2im.hpp +++ b/src/routines/levelx/xcol2im.hpp @@ -29,7 +29,8 @@ class Xcol2im: public Routine { Xcol2im(Queue &queue, EventPointer event, const std::string &name = "COL2IM"); // Templated-precision implementation of the routine - void DoCol2im(const size_t channels, const size_t height, const size_t width, + void DoCol2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp index f26f23a7..88127b0f 100644 --- a/src/routines/levelx/xconvgemm.cpp +++ b/src/routines/levelx/xconvgemm.cpp @@ -43,7 +43,8 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam // ================================================================================================= template <typename T> -void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const size_t width, +void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, @@ -94,7 +95,8 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const const auto col_batch_offset = batch_id * patch_size * num_patches; auto im2col_event = Event(); auto im2col = Xim2col<T>(queue_, im2col_event.pointer()); - im2col.DoIm2col(channels, height, width, kernel_h, kernel_w, + im2col.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_batch_offset, col_buffer, col_batch_offset); diff --git a/src/routines/levelx/xconvgemm.hpp b/src/routines/levelx/xconvgemm.hpp index 9d11ccee..20cfff60 100644 --- a/src/routines/levelx/xconvgemm.hpp +++ b/src/routines/levelx/xconvgemm.hpp @@ -32,7 +32,8 @@ class Xconvgemm: public Routine { const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col); // Templated-precision implementation of the routine - void DoConvgemm(const size_t channels, const size_t height, const size_t width, + void DoConvgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp index 09dcc42c..0f786974 100644 --- a/src/routines/levelx/xim2col.cpp +++ b/src/routines/levelx/xim2col.cpp @@ -22,22 +22,26 @@ namespace clblast { // Constructor: forwards to base class constructor template <typename T> Xim2col<T>::Xim2col(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, { -#include "../../kernels/levelx/im2col.opencl" - }) { + Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, { + #include "../../kernels/levelx/im2col.opencl" + }) { } // ================================================================================================= // The main routine template <typename T> -void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size_t width, +void Xim2col<T>::DoIm2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const Buffer<T> &im_buffer, const size_t im_offset, const Buffer<T> &col_buffer, const size_t col_offset) { + // Flip the output along kernel_h and kernel_w, or not. + const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xim2colKernelFlip" : "Xim2colKernelNormal"; + // Makes sure all dimensions are larger than zero if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); } @@ -50,7 +54,7 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1; // Retrieves the kernel from the compiled binary - auto kernel = Kernel(program_, "im2col"); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(height)); diff --git a/src/routines/levelx/xim2col.hpp b/src/routines/levelx/xim2col.hpp index 2c03b169..77cc32eb 100644 --- a/src/routines/levelx/xim2col.hpp +++ b/src/routines/levelx/xim2col.hpp @@ -29,7 +29,8 @@ class Xim2col: public Routine { Xim2col(Queue &queue, EventPointer event, const std::string &name = "IM2COL"); // Templated-precision implementation of the routine - void DoIm2col(const size_t channels, const size_t height, const size_t width, + void DoIm2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index a6cd82e7..a0e89c98 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -175,6 +175,13 @@ std::string ToString(Precision value) { } } template <> +std::string ToString(KernelMode value) { + switch(value) { + case KernelMode::kCrossCorrelation: return ToString(static_cast<int>(value))+" (cross-correlation)"; + case KernelMode::kConvolution: return ToString(static_cast<int>(value))+" (convolution)"; + } +} +template <> std::string ToString(StatusCode value) { return std::to_string(static_cast<int>(value)); } @@ -281,6 +288,7 @@ template Side GetArgument<Side>(const std::vector<std::string>&, std::string&, c template Triangle GetArgument<Triangle>(const std::vector<std::string>&, std::string&, const std::string&, const Triangle); template Diagonal GetArgument<Diagonal>(const std::vector<std::string>&, std::string&, const std::string&, const Diagonal); template Precision GetArgument<Precision>(const std::vector<std::string>&, std::string&, const std::string&, const Precision); +template KernelMode GetArgument<KernelMode>(const std::vector<std::string>&, std::string&, const std::string&, const KernelMode); // ================================================================================================= diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp index fcc1c57f..23486d35 100644 --- a/src/utilities/utilities.hpp +++ b/src/utilities/utilities.hpp @@ -69,6 +69,7 @@ constexpr auto kArgBTransp = "transB"; constexpr auto kArgSide = "side"; constexpr auto kArgTriangle = "triangle"; constexpr auto kArgDiagonal = "diagonal"; +constexpr auto kArgKernelMode = "kernel_mode"; constexpr auto kArgXInc = "incx"; constexpr auto kArgYInc = "incy"; constexpr auto kArgXOffset = "offx"; @@ -183,6 +184,7 @@ struct Arguments { Side side = Side::kLeft; Triangle triangle = Triangle::kUpper; Diagonal diagonal = Diagonal::kUnit; + KernelMode kernel_mode = KernelMode::kCrossCorrelation; size_t x_inc = 1; size_t y_inc = 1; size_t x_offset = 0; |