diff options
author | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-11-12 10:12:07 +0900 |
---|---|---|
committer | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-11-12 10:12:07 +0900 |
commit | 032e3b0cc00a15dd2af8b4fb82d261eb7b086e26 (patch) | |
tree | cdcf4d0fc342c9ff92ee7ab3f75b0cdeced46e96 /src/clblast_c.cpp | |
parent | 90112618daa0d6b24ae3e53203a636d2e908dfba (diff) |
Add kernel_mode option to im2col, col2im, and convgemm functions
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r-- | src/clblast_c.cpp | 78 |
1 files changed, 52 insertions, 26 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index 645a69b1..a224230a 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -3613,65 +3613,75 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran } // IM2COL -CLBlastStatusCode CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastSim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<float>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastDim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<double>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastCim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<float2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastZim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<double2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastHim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Im2col<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col<half>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) @@ -3680,65 +3690,75 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con } // COL2IM -CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastScol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<float>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastDcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<double>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastCcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<float2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastZcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<double2>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastHcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Col2im<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im<half>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) @@ -3747,14 +3767,16 @@ CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, con } // CONVGEMM -CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastSconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Convgemm<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm<float>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, @@ -3762,14 +3784,16 @@ CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, c ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastDconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Convgemm<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm<double>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, @@ -3777,14 +3801,16 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c ); } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastHconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast<CLBlastStatusCode>( - clblast::Convgemm<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm<half>(static_cast<clblast::KernelMode>(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, |