diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/clblast.h | 7 | ||||
-rw-r--r-- | include/clblast_c.h | 22 | ||||
-rw-r--r-- | include/clblast_cuda.h | 7 | ||||
-rw-r--r-- | include/clblast_netlib_c.h | 14 |
4 files changed, 50 insertions, 0 deletions
diff --git a/include/clblast.h b/include/clblast.h index 9a8988e7..27adf7fa 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -636,6 +636,13 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event = nullptr); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +template <typename T> +StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event = nullptr); + // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template <typename T> StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, diff --git a/include/clblast_c.h b/include/clblast_c.h index 2357182c..1c681bfe 100644 --- a/include/clblast_c.h +++ b/include/clblast_c.h @@ -1410,6 +1410,28 @@ CLBlastStatusCode PUBLIC_API CLBlastHim2col(const size_t channels, const size_t cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +CLBlastStatusCode PUBLIC_API CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); +CLBlastStatusCode PUBLIC_API CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); +CLBlastStatusCode PUBLIC_API CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); +CLBlastStatusCode PUBLIC_API CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); +CLBlastStatusCode PUBLIC_API CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); + // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h index 1bbd898e..58f9b74b 100644 --- a/include/clblast_cuda.h +++ b/include/clblast_cuda.h @@ -608,6 +608,13 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width CUdeviceptr col_buffer, const size_t col_offset, const CUcontext context, const CUdevice device); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +template <typename T> +StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const CUdeviceptr col_buffer, const size_t col_offset, + CUdeviceptr im_buffer, const size_t im_offset, + const CUcontext context, const CUdevice device); + // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template <typename T> StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, diff --git a/include/clblast_netlib_c.h b/include/clblast_netlib_c.h index b64b82eb..65545bfb 100644 --- a/include/clblast_netlib_c.h +++ b/include/clblast_netlib_c.h @@ -960,6 +960,20 @@ void PUBLIC_API cblas_zim2col(const int channels, const int height, const int wi const void* im, void* col); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +void PUBLIC_API cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const float* col, + float* im); +void PUBLIC_API cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const double* col, + double* im); +void PUBLIC_API cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const void* col, + void* im); +void PUBLIC_API cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const void* col, + void* im); + // ================================================================================================= #ifdef __cplusplus |