From 032e3b0cc00a15dd2af8b4fb82d261eb7b086e26 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 12 Nov 2018 10:12:07 +0900 Subject: Add kernel_mode option to im2col, col2im, and convgemm functions --- doc/api.md | 51 ++++++++++++------ include/clblast.h | 10 ++-- include/clblast_c.h | 40 +++++++++----- include/clblast_cuda.h | 10 ++-- include/clblast_netlib_c.h | 25 ++++++--- scripts/generator/generator.py | 8 +-- scripts/generator/generator/convert.py | 2 + src/clblast.cpp | 57 +++++++++++++------- src/clblast_c.cpp | 78 ++++++++++++++++++--------- src/clblast_cuda.cpp | 57 +++++++++++++------- src/clblast_netlib_c.cpp | 48 +++++++++++------ src/kernels/levelx/col2im.opencl | 73 ++++++++++++++++++++----- src/kernels/levelx/im2col.opencl | 59 ++++++++++++++++---- src/routines/levelx/xcol2im.cpp | 8 ++- src/routines/levelx/xcol2im.hpp | 3 +- src/routines/levelx/xconvgemm.cpp | 6 ++- src/routines/levelx/xconvgemm.hpp | 3 +- src/routines/levelx/xim2col.cpp | 14 +++-- src/routines/levelx/xim2col.hpp | 3 +- src/utilities/utilities.cpp | 8 +++ src/utilities/utilities.hpp | 2 + test/correctness/misc/override_parameters.cpp | 1 + test/correctness/testblas.hpp | 38 +++++++------ test/routines/levelx/xcol2im.hpp | 14 +++-- test/routines/levelx/xconvgemm.hpp | 20 ++++--- test/routines/levelx/xim2col.hpp | 14 +++-- 26 files changed, 461 insertions(+), 191 deletions(-) diff --git a/doc/api.md b/doc/api.md index 337b5af9..996505f1 100644 --- a/doc/api.md +++ b/doc/api.md @@ -3020,7 +3020,8 @@ Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is th C++ API: ``` template -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) @@ -3028,23 +3029,28 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width C API: ``` -CLBlastStatusCode CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastSim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastDim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastCim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastZim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastHim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) @@ -3052,6 +3058,7 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con Arguments to IM2COL: +* `const KernelMode kernel_mode`: The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes. * `const size_t channels`: Integer size argument. This value must be positive. * `const size_t height`: Integer size argument. This value must be positive. * `const size_t width`: Integer size argument. This value must be positive. @@ -3080,7 +3087,8 @@ Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is th C++ API: ``` template -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) @@ -3088,23 +3096,28 @@ StatusCode Col2im(const size_t channels, const size_t height, const size_t width C API: ``` -CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastScol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastDcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastCcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastZcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastHcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) @@ -3112,6 +3125,7 @@ CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, con Arguments to COL2IM: +* `const KernelMode kernel_mode`: The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes. * `const size_t channels`: Integer size argument. This value must be positive. * `const size_t height`: Integer size argument. This value must be positive. * `const size_t width`: Integer size argument. This value must be positive. @@ -3140,7 +3154,8 @@ Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D i C++ API: ``` template -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, @@ -3149,17 +3164,20 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid C API: ``` -CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastSconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastDconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) -CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastHconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, @@ -3168,6 +3186,7 @@ CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, c Arguments to CONVGEMM: +* `const KernelMode kernel_mode`: The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes. * `const size_t channels`: Integer size argument. This value must be positive. * `const size_t height`: Integer size argument. This value must be positive. * `const size_t width`: Integer size argument. This value must be positive. diff --git a/include/clblast.h b/include/clblast.h index 27adf7fa..7a82361c 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -117,6 +117,7 @@ enum class Transpose { kNo = 111, kYes = 112, kConjugate = 113 }; enum class Triangle { kUpper = 121, kLower = 122 }; enum class Diagonal { kNonUnit = 131, kUnit = 132 }; enum class Side { kLeft = 141, kRight = 142 }; +enum class KernelMode { kCrossCorrelation = 151, kConvolution = 152 }; // Precision scoped enum (values in bits) enum class Precision { kHalf = 16, kSingle = 32, kDouble = 64, @@ -631,21 +632,24 @@ StatusCode Omatcopy(const Layout layout, const Transpose a_transpose, // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event = nullptr); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event = nullptr); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, diff --git a/include/clblast_c.h b/include/clblast_c.h index 1c681bfe..2ba6375a 100644 --- a/include/clblast_c.h +++ b/include/clblast_c.h @@ -120,6 +120,7 @@ typedef enum CLBlastTriangle_ { CLBlastTriangleUpper = 121, typedef enum CLBlastDiagonal_ { CLBlastDiagonalNonUnit = 131, CLBlastDiagonalUnit = 132 } CLBlastDiagonal; typedef enum CLBlastSide_ { CLBlastSideLeft = 141, CLBlastSideRight = 142 } CLBlastSide; +typedef enum CLBlastKernelMode_ { CLBlastKernelModeCrossCorrelation = 151, CLBlastKernelModeConvolution = 152 } CLBlastKernelMode; // Precision enum (values in bits) typedef enum CLBlastPrecision_ { CLBlastPrecisionHalf = 16, CLBlastPrecisionSingle = 32, @@ -1389,61 +1390,74 @@ CLBlastStatusCode PUBLIC_API CLBlastHomatcopy(const CLBlastLayout layout, const cl_command_queue* queue, cl_event* event); // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL -CLBlastStatusCode PUBLIC_API CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastSim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastDim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastCim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastZim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastHim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM -CLBlastStatusCode PUBLIC_API CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastScol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastDcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastCcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastZcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode PUBLIC_API CLBlastHcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM -CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode PUBLIC_API CLBlastDconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event); -CLBlastStatusCode PUBLIC_API CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode PUBLIC_API CLBlastHconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h index 58f9b74b..f6d6372d 100644 --- a/include/clblast_cuda.h +++ b/include/clblast_cuda.h @@ -89,6 +89,7 @@ enum class Transpose { kNo = 111, kYes = 112, kConjugate = 113 }; enum class Triangle { kUpper = 121, kLower = 122 }; enum class Diagonal { kNonUnit = 131, kUnit = 132 }; enum class Side { kLeft = 141, kRight = 142 }; +enum class KernelMode { kCrossCorrelation = 151, kConvolution = 152 }; // Precision scoped enum (values in bits) enum class Precision { kHalf = 16, kSingle = 32, kDouble = 64, @@ -603,21 +604,24 @@ StatusCode Omatcopy(const Layout layout, const Transpose a_transpose, // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr im_buffer, const size_t im_offset, CUdeviceptr col_buffer, const size_t col_offset, const CUcontext context, const CUdevice device); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr col_buffer, const size_t col_offset, CUdeviceptr im_buffer, const size_t im_offset, const CUcontext context, const CUdevice device); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const CUdeviceptr im_buffer, const size_t im_offset, const CUdeviceptr kernel_buffer, const size_t kernel_offset, CUdeviceptr result_buffer, const size_t result_offset, diff --git a/include/clblast_netlib_c.h b/include/clblast_netlib_c.h index 65545bfb..4c54fb18 100644 --- a/include/clblast_netlib_c.h +++ b/include/clblast_netlib_c.h @@ -45,6 +45,7 @@ typedef enum CLBlastTriangle_ { CLBlastTriangleUpper = 121, typedef enum CLBlastDiagonal_ { CLBlastDiagonalNonUnit = 131, CLBlastDiagonalUnit = 132 } CLBlastDiagonal; typedef enum CLBlastSide_ { CLBlastSideLeft = 141, CLBlastSideRight = 142 } CLBlastSide; +typedef enum CLBlastKernelMode_ { CLBlastKernelModeCrossCorrelation = 141, CLBlastKernelModeConvolution = 152 } CLBlastKernelMode; // For full compatibility with CBLAS typedef CLBlastLayout CBLAS_ORDER; @@ -947,30 +948,38 @@ void PUBLIC_API cblas_zomatcopy(const CLBlastLayout layout, const CLBlastTranspo void* b, const int b_ld); // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL -void PUBLIC_API cblas_sim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void PUBLIC_API cblas_sim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const float* im, float* col); -void PUBLIC_API cblas_dim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void PUBLIC_API cblas_dim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const double* im, double* col); -void PUBLIC_API cblas_cim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void PUBLIC_API cblas_cim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* im, void* col); -void PUBLIC_API cblas_zim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void PUBLIC_API cblas_zim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* im, void* col); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM -void PUBLIC_API cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void PUBLIC_API cblas_scol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const float* col, float* im); -void PUBLIC_API cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void PUBLIC_API cblas_dcol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const double* col, double* im); -void PUBLIC_API cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void PUBLIC_API cblas_ccol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* col, void* im); -void PUBLIC_API cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void PUBLIC_API cblas_zcol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* col, void* im); diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index f8022d81..68e3f01a 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -49,7 +49,7 @@ FILES = [ "/src/clblast_cuda.cpp", "/src/pyclblast/src/pyclblast.pyx" ] -HEADER_LINES = [123, 21, 127, 24, 29, 45, 29, 65, 40, 95, 21, 290] +HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 290] FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 232 @@ -180,9 +180,9 @@ ROUTINES = [ # Special routines: Routine(True, True, 0, False, "x", "had", T, [S,D,C,Z,H], ["n"], [], ["x","y"], ["z"], [xn,yn,zn], ["alpha","beta"], "", "Element-wise vector product (Hadamard)", "Performs the Hadamard element-wise product _z = alpha * x * y + beta * z_, in which _x_, _y_, and _z_ are vectors and _alpha_ and _beta_ are scalar constants.", []), Routine(True, True, 0, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]), - Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix. Overwrites any existing values in the _col_ buffer", []), - Routine(True, True, 0, False, "x", "col2im", T, [S,D,C,Z,H], im2col_constants, [], ["col"], ["im"], [col,im], [""], "", "Col2im function (non-BLAS function)", "Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is the output matrix. Accumulates results on top of the existing values in the _im_ buffer.", []), - Routine(True, True, 0, False, "x", "convgemm", T, [S,D,H], convgemm_constants, [], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []), + Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, ["kernel_mode"], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix. Overwrites any existing values in the _col_ buffer", []), + Routine(True, True, 0, False, "x", "col2im", T, [S,D,C,Z,H], im2col_constants, ["kernel_mode"], ["col"], ["im"], [col,im], [""], "", "Col2im function (non-BLAS function)", "Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is the output matrix. Accumulates results on top of the existing values in the _im_ buffer.", []), + Routine(True, True, 0, False, "x", "convgemm", T, [S,D,H], convgemm_constants, ["kernel_mode"], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []), # Batched routines: Routine(True, True, 1, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []), Routine(True, True, 1, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]), diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py index 07f45669..16890d27 100644 --- a/scripts/generator/generator/convert.py +++ b/scripts/generator/generator/convert.py @@ -27,6 +27,7 @@ def option_to_clblast(x): 'side': "Side", 'triangle': "Triangle", 'diagonal': "Diagonal", + 'kernel_mode': "KernelMode", }[x] @@ -79,4 +80,5 @@ def option_to_documentation(x): 'side': "The position of the triangular matrix in the operation, either on the `Side::kLeft` (141) or `Side::kRight` (142).", 'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).", 'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.", + 'kernel_mode': "The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes.", }[x] diff --git a/src/clblast.cpp b/src/clblast.cpp index e45f504a..180693e7 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2218,79 +2218,94 @@ template StatusCode PUBLIC_API Omatcopy(const Layout, const Transpose, // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { auto queue_cpp = Queue(*queue); auto routine = Xim2col(queue_cpp, event); - routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer(im_buffer), im_offset, Buffer(col_buffer), col_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { auto queue_cpp = Queue(*queue); auto routine = Xcol2im(queue_cpp, event); - routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoCol2im(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer(col_buffer), col_offset, Buffer(im_buffer), im_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, @@ -2298,24 +2313,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid try { auto queue_cpp = Queue(*queue); auto routine = Xconvgemm(queue_cpp, event); - routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + routine.DoConvgemm(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, Buffer(im_buffer), im_offset, Buffer(kernel_buffer), kernel_offset, Buffer(result_buffer), result_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index 645a69b1..a224230a 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -3613,65 +3613,75 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran } // IM2COL -CLBlastStatusCode CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastSim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastDim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastCim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastZim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastHim2col(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_offset, col_buffer, col_offset, queue, event) @@ -3680,65 +3690,75 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con } // COL2IM -CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastScol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastDcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastCcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastZcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +CLBlastStatusCode CLBlastHcol2im(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer, col_offset, im_buffer, im_offset, queue, event) @@ -3747,14 +3767,16 @@ CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, con } // CONVGEMM -CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastSconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Convgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, @@ -3762,14 +3784,16 @@ CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, c ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastDconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Convgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, @@ -3777,14 +3801,16 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c ); } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } -CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +CLBlastStatusCode CLBlastHconvgemm(const CLBlastKernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, cl_command_queue* queue, cl_event* event) { try { return static_cast( - clblast::Convgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + clblast::Convgemm(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, im_buffer, im_offset, kernel_buffer, kernel_offset, result_buffer, result_offset, diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp index 03d995ba..264f360d 100644 --- a/src/clblast_cuda.cpp +++ b/src/clblast_cuda.cpp @@ -2314,7 +2314,8 @@ template StatusCode PUBLIC_API Omatcopy(const Layout, const Transpose, // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr im_buffer, const size_t im_offset, CUdeviceptr col_buffer, const size_t col_offset, const CUcontext context, const CUdevice device) { @@ -2323,36 +2324,43 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xim2col(queue_cpp, nullptr); - routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer(im_buffer), im_offset, Buffer(col_buffer), col_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr col_buffer, const size_t col_offset, CUdeviceptr im_buffer, const size_t im_offset, const CUcontext context, const CUdevice device) { @@ -2361,36 +2369,43 @@ StatusCode Col2im(const size_t channels, const size_t height, const size_t width const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xcol2im(queue_cpp, nullptr); - routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoCol2im(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer(col_buffer), col_offset, Buffer(im_buffer), im_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const CUdeviceptr im_buffer, const size_t im_offset, const CUdeviceptr kernel_buffer, const size_t kernel_offset, CUdeviceptr result_buffer, const size_t result_offset, @@ -2400,24 +2415,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xconvgemm(queue_cpp, nullptr); - routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + routine.DoConvgemm(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, Buffer(im_buffer), im_offset, Buffer(kernel_buffer), kernel_offset, Buffer(result_buffer), result_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, diff --git a/src/clblast_netlib_c.cpp b/src/clblast_netlib_c.cpp index 22570535..3a8f729e 100644 --- a/src/clblast_netlib_c.cpp +++ b/src/clblast_netlib_c.cpp @@ -4878,7 +4878,8 @@ void cblas_zomatcopy(const CLBlastLayout layout, const CLBlastTranspose a_transp } // IM2COL -void cblas_sim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_sim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const float* im, float* col) { OPTIONAL_STATIC auto device = get_device(); @@ -4891,7 +4892,8 @@ void cblas_sim2col(const int channels, const int height, const int width, const im_buffer.Write(queue, im_size, reinterpret_cast(im)); col_buffer.Write(queue, col_size, reinterpret_cast(col)); auto queue_cl = queue(); - auto s = clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer(), 0, col_buffer(), 0, &queue_cl); @@ -4900,7 +4902,8 @@ void cblas_sim2col(const int channels, const int height, const int width, const } col_buffer.Read(queue, col_size, reinterpret_cast(col)); } -void cblas_dim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_dim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const double* im, double* col) { OPTIONAL_STATIC auto device = get_device(); @@ -4913,7 +4916,8 @@ void cblas_dim2col(const int channels, const int height, const int width, const im_buffer.Write(queue, im_size, reinterpret_cast(im)); col_buffer.Write(queue, col_size, reinterpret_cast(col)); auto queue_cl = queue(); - auto s = clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer(), 0, col_buffer(), 0, &queue_cl); @@ -4922,7 +4926,8 @@ void cblas_dim2col(const int channels, const int height, const int width, const } col_buffer.Read(queue, col_size, reinterpret_cast(col)); } -void cblas_cim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_cim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* im, void* col) { OPTIONAL_STATIC auto device = get_device(); @@ -4935,7 +4940,8 @@ void cblas_cim2col(const int channels, const int height, const int width, const im_buffer.Write(queue, im_size, reinterpret_cast(im)); col_buffer.Write(queue, col_size, reinterpret_cast(col)); auto queue_cl = queue(); - auto s = clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer(), 0, col_buffer(), 0, &queue_cl); @@ -4944,7 +4950,8 @@ void cblas_cim2col(const int channels, const int height, const int width, const } col_buffer.Read(queue, col_size, reinterpret_cast(col)); } -void cblas_zim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_zim2col(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* im, void* col) { OPTIONAL_STATIC auto device = get_device(); @@ -4957,7 +4964,8 @@ void cblas_zim2col(const int channels, const int height, const int width, const im_buffer.Write(queue, im_size, reinterpret_cast(im)); col_buffer.Write(queue, col_size, reinterpret_cast(col)); auto queue_cl = queue(); - auto s = clblast::Im2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Im2col(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer(), 0, col_buffer(), 0, &queue_cl); @@ -4968,7 +4976,8 @@ void cblas_zim2col(const int channels, const int height, const int width, const } // COL2IM -void cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_scol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const float* col, float* im) { OPTIONAL_STATIC auto device = get_device(); @@ -4981,7 +4990,8 @@ void cblas_scol2im(const int channels, const int height, const int width, const col_buffer.Write(queue, col_size, reinterpret_cast(col)); im_buffer.Write(queue, im_size, reinterpret_cast(im)); auto queue_cl = queue(); - auto s = clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer(), 0, im_buffer(), 0, &queue_cl); @@ -4990,7 +5000,8 @@ void cblas_scol2im(const int channels, const int height, const int width, const } im_buffer.Read(queue, im_size, reinterpret_cast(im)); } -void cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_dcol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const double* col, double* im) { OPTIONAL_STATIC auto device = get_device(); @@ -5003,7 +5014,8 @@ void cblas_dcol2im(const int channels, const int height, const int width, const col_buffer.Write(queue, col_size, reinterpret_cast(col)); im_buffer.Write(queue, im_size, reinterpret_cast(im)); auto queue_cl = queue(); - auto s = clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer(), 0, im_buffer(), 0, &queue_cl); @@ -5012,7 +5024,8 @@ void cblas_dcol2im(const int channels, const int height, const int width, const } im_buffer.Read(queue, im_size, reinterpret_cast(im)); } -void cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_ccol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* col, void* im) { OPTIONAL_STATIC auto device = get_device(); @@ -5025,7 +5038,8 @@ void cblas_ccol2im(const int channels, const int height, const int width, const col_buffer.Write(queue, col_size, reinterpret_cast(col)); im_buffer.Write(queue, im_size, reinterpret_cast(im)); auto queue_cl = queue(); - auto s = clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer(), 0, im_buffer(), 0, &queue_cl); @@ -5034,7 +5048,8 @@ void cblas_ccol2im(const int channels, const int height, const int width, const } im_buffer.Read(queue, im_size, reinterpret_cast(im)); } -void cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, +void cblas_zcol2im(const CLBlastKernelMode kernel_mode, + const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const void* col, void* im) { OPTIONAL_STATIC auto device = get_device(); @@ -5047,7 +5062,8 @@ void cblas_zcol2im(const int channels, const int height, const int width, const col_buffer.Write(queue, col_size, reinterpret_cast(col)); im_buffer.Write(queue, im_size, reinterpret_cast(im)); auto queue_cl = queue(); - auto s = clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + auto s = clblast::Col2im(static_cast(kernel_mode), + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, col_buffer(), 0, im_buffer(), 0, &queue_cl); diff --git a/src/kernels/levelx/col2im.opencl b/src/kernels/levelx/col2im.opencl index a37db24f..484a7a98 100644 --- a/src/kernels/levelx/col2im.opencl +++ b/src/kernels/levelx/col2im.opencl @@ -28,18 +28,20 @@ inline int grid_ceil(const int x, const int step) { return x > 0 ? ((x - 1) / step + 1) * step : x / step * step; } +// Main body of the kernel __kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) -void col2im(const int input_h, const int input_w, const int channels, - const int output_h, const int output_w, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int stride_bez_h, const int stride_bez_w, - const int dilation_bez_h, const int dilation_bez_w, - const int gcd_h, const int gcd_w, - const __global real* restrict col_buffer, const int col_offset, - __global real* im_buffer, const int im_offset) { +INLINE_FUNC void Xcol2im(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const bool kernel_flip, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { const int input_h_scaled = (input_h - 1) / gcd_h + 1; @@ -71,8 +73,9 @@ void col2im(const int input_h, const int input_w, const int channels, const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w; const int h_id = th / stride_h + stride_bez_h * gcd_scale_h; const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w; - - const int kernel_index = kw_id + kernel_w * kh_id; + const int kernel_index = (kernel_flip) + ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1 + : kw_id + kernel_w * kh_id; const int patch_index = w_id + output_w * h_id; const int output_index = patch_index + kernel_index * output_w * output_h + c_id * output_w * output_h * kernel_h * kernel_w; @@ -89,6 +92,50 @@ void col2im(const int input_h, const int input_w, const int channels, // ================================================================================================= +// Kernel flip version of the Xcol2im kernel (for convolution) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xcol2imKernelFlip(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { + const bool kernel_flip = true; + Xcol2im(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + stride_bez_h, stride_bez_w, dilation_bez_h, dilation_bez_w, gcd_h, gcd_w, + kernel_flip, + col_buffer, col_offset, im_buffer, im_offset); +} + +// Normal version of the Xcol2im kernel (for cross-correlation) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xcol2imKernelNormal(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { + const bool kernel_flip = false; + Xcol2im(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + stride_bez_h, stride_bez_w, dilation_bez_h, dilation_bez_w, gcd_h, gcd_w, + kernel_flip, + col_buffer, col_offset, im_buffer, im_offset); +} + +// ================================================================================================= + // End of the C++11 raw string literal )" diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl index 301e076b..5db4cb5f 100644 --- a/src/kernels/levelx/im2col.opencl +++ b/src/kernels/levelx/im2col.opencl @@ -25,15 +25,16 @@ R"( // ================================================================================================= -__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) -void im2col(const int input_h, const int input_w, const int channels, - const int output_h, const int output_w, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const __global real* restrict im_buffer, const int im_offset, - __global real* col_buffer, const int col_offset) { +// Main body of the kernel +INLINE_FUNC void Xim2col(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const bool kernel_flip, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { // Thread IDs const int w_id = get_global_id(0); // image width, max 'output_w' @@ -58,7 +59,9 @@ void im2col(const int input_h, const int input_w, const int channels, } // Sets the output value - const int kernel_index = kw_id + kernel_w * kh_id; + const int kernel_index = (kernel_flip) + ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1 + : kw_id + kernel_w * kh_id; const int patch_index = w_id + output_w * h_id; const int output_index = patch_index + kernel_index * output_w * output_h + c_id * output_w * output_h * kernel_h * kernel_w; @@ -70,6 +73,42 @@ void im2col(const int input_h, const int input_w, const int channels, // ================================================================================================= +// Kernel flip version of the Xim2col kernel (for convolution) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xim2colKernelFlip(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { + const bool kernel_flip = true; + Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + kernel_flip, + im_buffer, im_offset, col_buffer, col_offset); +} + +// Normal version of the Xim2col kernel (for cross-correlation) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xim2colKernelNormal(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { + const bool kernel_flip = false; + Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + kernel_flip, + im_buffer, im_offset, col_buffer, col_offset); +} + +// ================================================================================================= + // End of the C++11 raw string literal )" diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp index 7a0c36b7..d285e5c0 100644 --- a/src/routines/levelx/xcol2im.cpp +++ b/src/routines/levelx/xcol2im.cpp @@ -31,13 +31,17 @@ Xcol2im::Xcol2im(Queue &queue, EventPointer event, const std::string &name): // The main routine template -void Xcol2im::DoCol2im(const size_t channels, const size_t height, const size_t width, +void Xcol2im::DoCol2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const Buffer &col_buffer, const size_t col_offset, const Buffer &im_buffer, const size_t im_offset) { + // Flip the output along kernel_h and kernel_w, or not. + const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xcol2imKernelFlip" : "Xcol2imKernelNormal"; + // Makes sure all dimensions are larger than zero if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); } @@ -59,7 +63,7 @@ void Xcol2im::DoCol2im(const size_t channels, const size_t height, const size EuclidGCD(static_cast(stride_w), static_cast(dilation_w), stride_bez_w, dilation_bez_w, gcd_w); // Retrieves the kernel from the compiled binary - auto kernel = Kernel(program_, "col2im"); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(height)); diff --git a/src/routines/levelx/xcol2im.hpp b/src/routines/levelx/xcol2im.hpp index 86d68c45..522c717e 100644 --- a/src/routines/levelx/xcol2im.hpp +++ b/src/routines/levelx/xcol2im.hpp @@ -29,7 +29,8 @@ class Xcol2im: public Routine { Xcol2im(Queue &queue, EventPointer event, const std::string &name = "COL2IM"); // Templated-precision implementation of the routine - void DoCol2im(const size_t channels, const size_t height, const size_t width, + void DoCol2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp index f26f23a7..88127b0f 100644 --- a/src/routines/levelx/xconvgemm.cpp +++ b/src/routines/levelx/xconvgemm.cpp @@ -43,7 +43,8 @@ Xconvgemm::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam // ================================================================================================= template -void Xconvgemm::DoConvgemm(const size_t channels, const size_t height, const size_t width, +void Xconvgemm::DoConvgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, @@ -94,7 +95,8 @@ void Xconvgemm::DoConvgemm(const size_t channels, const size_t height, const const auto col_batch_offset = batch_id * patch_size * num_patches; auto im2col_event = Event(); auto im2col = Xim2col(queue_, im2col_event.pointer()); - im2col.DoIm2col(channels, height, width, kernel_h, kernel_w, + im2col.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_batch_offset, col_buffer, col_batch_offset); diff --git a/src/routines/levelx/xconvgemm.hpp b/src/routines/levelx/xconvgemm.hpp index 9d11ccee..20cfff60 100644 --- a/src/routines/levelx/xconvgemm.hpp +++ b/src/routines/levelx/xconvgemm.hpp @@ -32,7 +32,8 @@ class Xconvgemm: public Routine { const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col); // Templated-precision implementation of the routine - void DoConvgemm(const size_t channels, const size_t height, const size_t width, + void DoConvgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp index 09dcc42c..0f786974 100644 --- a/src/routines/levelx/xim2col.cpp +++ b/src/routines/levelx/xim2col.cpp @@ -22,22 +22,26 @@ namespace clblast { // Constructor: forwards to base class constructor template Xim2col::Xim2col(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, {"Copy"}, PrecisionValue(), {}, { -#include "../../kernels/levelx/im2col.opencl" - }) { + Routine(queue, event, name, {"Copy"}, PrecisionValue(), {}, { + #include "../../kernels/levelx/im2col.opencl" + }) { } // ================================================================================================= // The main routine template -void Xim2col::DoIm2col(const size_t channels, const size_t height, const size_t width, +void Xim2col::DoIm2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const Buffer &im_buffer, const size_t im_offset, const Buffer &col_buffer, const size_t col_offset) { + // Flip the output along kernel_h and kernel_w, or not. + const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xim2colKernelFlip" : "Xim2colKernelNormal"; + // Makes sure all dimensions are larger than zero if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); } @@ -50,7 +54,7 @@ void Xim2col::DoIm2col(const size_t channels, const size_t height, const size const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1; // Retrieves the kernel from the compiled binary - auto kernel = Kernel(program_, "im2col"); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(height)); diff --git a/src/routines/levelx/xim2col.hpp b/src/routines/levelx/xim2col.hpp index 2c03b169..77cc32eb 100644 --- a/src/routines/levelx/xim2col.hpp +++ b/src/routines/levelx/xim2col.hpp @@ -29,7 +29,8 @@ class Xim2col: public Routine { Xim2col(Queue &queue, EventPointer event, const std::string &name = "IM2COL"); // Templated-precision implementation of the routine - void DoIm2col(const size_t channels, const size_t height, const size_t width, + void DoIm2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index a6cd82e7..a0e89c98 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -175,6 +175,13 @@ std::string ToString(Precision value) { } } template <> +std::string ToString(KernelMode value) { + switch(value) { + case KernelMode::kCrossCorrelation: return ToString(static_cast(value))+" (cross-correlation)"; + case KernelMode::kConvolution: return ToString(static_cast(value))+" (convolution)"; + } +} +template <> std::string ToString(StatusCode value) { return std::to_string(static_cast(value)); } @@ -281,6 +288,7 @@ template Side GetArgument(const std::vector&, std::string&, c template Triangle GetArgument(const std::vector&, std::string&, const std::string&, const Triangle); template Diagonal GetArgument(const std::vector&, std::string&, const std::string&, const Diagonal); template Precision GetArgument(const std::vector&, std::string&, const std::string&, const Precision); +template KernelMode GetArgument(const std::vector&, std::string&, const std::string&, const KernelMode); // ================================================================================================= diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp index fcc1c57f..23486d35 100644 --- a/src/utilities/utilities.hpp +++ b/src/utilities/utilities.hpp @@ -69,6 +69,7 @@ constexpr auto kArgBTransp = "transB"; constexpr auto kArgSide = "side"; constexpr auto kArgTriangle = "triangle"; constexpr auto kArgDiagonal = "diagonal"; +constexpr auto kArgKernelMode = "kernel_mode"; constexpr auto kArgXInc = "incx"; constexpr auto kArgYInc = "incy"; constexpr auto kArgXOffset = "offx"; @@ -183,6 +184,7 @@ struct Arguments { Side side = Side::kLeft; Triangle triangle = Triangle::kUpper; Diagonal diagonal = Diagonal::kUnit; + KernelMode kernel_mode = KernelMode::kCrossCorrelation; size_t x_inc = 1; size_t y_inc = 1; size_t x_offset = 0; diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp index 54229c5e..7ed4faff 100644 --- a/test/correctness/misc/override_parameters.cpp +++ b/test/correctness/misc/override_parameters.cpp @@ -60,6 +60,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st args.layout = GetArgument(arguments, help, kArgLayout, Layout::kRowMajor); args.a_transpose = GetArgument(arguments, help, kArgATransp, Transpose::kNo); args.b_transpose = GetArgument(arguments, help, kArgBTransp, Transpose::kNo); + args.kernel_mode = GetArgument(arguments, help, kArgKernelMode, KernelMode::kCrossCorrelation); args.alpha = GetArgument(arguments, help, kArgAlpha, GetScalar()); args.beta = GetArgument(arguments, help, kArgBeta, GetScalar()); diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index 137df30f..b2dc6e7a 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -63,6 +63,7 @@ class TestBlas: public Tester { static const std::vector kNumKernels; static const std::vector kStrideValues; static const std::vector kChannelValues; + static const std::vector kKernelModes; const std::vector kOffsets; const std::vector kAlphaValues; const std::vector kBetaValues; @@ -142,6 +143,7 @@ template const std::vector TestBlas::kKern template const std::vector TestBlas::kNumKernels = { 1, 6 }; template const std::vector TestBlas::kStrideValues = { 1, 3 }; template const std::vector TestBlas::kChannelValues = { 1, 2 }; +template const std::vector TestBlas::kKernelModes = { KernelMode::kCrossCorrelation, KernelMode::kConvolution }; // Test settings for the invalid tests template const std::vector TestBlas::kInvalidIncrements = { 0, 1 }; @@ -168,6 +170,7 @@ static StatusCode ReferenceNotAvailable(const Arguments &, BufferType &, Queu template void handle_remaining_of_options(std::vector> ®ular_test_vector, Arguments &r_args, TestBlas &tester, + const std::vector &kernel_modes, const std::vector &channelss, const std::vector &heights, const std::vector &widths, @@ -181,21 +184,23 @@ void handle_remaining_of_options(std::vector> ®ular_test_vector, const std::vector &dilation_ws, const std::vector &batch_counts, const std::vector &num_kernelss) { - for (auto &channels: channelss) { r_args.channels = channels; - for (auto &height: heights) { r_args.height = height; - for (auto &width: widths) { r_args.width = width; - for (auto &kernel_h: kernel_hs) { r_args.kernel_h = kernel_h; - for (auto &kernel_w: kernel_ws) { r_args.kernel_w = kernel_w; - for (auto &pad_h: pad_hs) { r_args.pad_h = pad_h; - for (auto &pad_w: pad_ws) { r_args.pad_w = pad_w; - for (auto &stride_h: stride_hs) { r_args.stride_h = stride_h; - for (auto &stride_w: stride_ws) { r_args.stride_w = stride_w; - for (auto &dilation_h: dilation_hs) { r_args.dilation_h = dilation_h; - for (auto &dilation_w: dilation_ws) { r_args.dilation_w = dilation_w; - for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count; - for (auto &num_kernels: num_kernelss) { r_args.num_kernels = num_kernels; - C::SetSizes(r_args, tester.queue_); - regular_test_vector.push_back(r_args); + for (auto &kernel_mode: kernel_modes) { r_args.kernel_mode = kernel_mode; + for (auto &channels: channelss) { r_args.channels = channels; + for (auto &height: heights) { r_args.height = height; + for (auto &width: widths) { r_args.width = width; + for (auto &kernel_h: kernel_hs) { r_args.kernel_h = kernel_h; + for (auto &kernel_w: kernel_ws) { r_args.kernel_w = kernel_w; + for (auto &pad_h: pad_hs) { r_args.pad_h = pad_h; + for (auto &pad_w: pad_ws) { r_args.pad_w = pad_w; + for (auto &stride_h: stride_hs) { r_args.stride_h = stride_h; + for (auto &stride_w: stride_ws) { r_args.stride_w = stride_w; + for (auto &dilation_h: dilation_hs) { r_args.dilation_h = dilation_h; + for (auto &dilation_w: dilation_ws) { r_args.dilation_w = dilation_w; + for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count; + for (auto &num_kernels: num_kernelss) { r_args.num_kernels = num_kernels; + C::SetSizes(r_args, tester.queue_); + regular_test_vector.push_back(r_args); + } } } } @@ -284,6 +289,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na auto imax_offsets = std::vector{args.imax_offset}; auto alphas = std::vector{args.alpha}; auto betas = std::vector{args.beta}; + auto kernel_modes = std::vector{args.kernel_mode}; auto channelss = std::vector{args.channels}; auto heights = std::vector{args.height}; auto widths = std::vector{args.width}; @@ -340,6 +346,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; } if (option == kArgAlpha) { alphas = tester.kAlphaValues; } if (option == kArgBeta) { betas = tester.kBetaValues; } + if (option == kArgKernelMode) { kernel_modes = tester.kKernelModes; } if (option == kArgChannels) { channelss = tester.kChannelValues; } if (option == kArgHeight) { heights = tester.kMatrixDims; } if (option == kArgWidth) { widths = tester.kMatrixDims; } @@ -397,6 +404,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na for (auto &beta: betas) { r_args.beta = beta; // Cannot have more for-loops because of MSVC's C1061 error handle_remaining_of_options(regular_test_vector, r_args, tester, + kernel_modes, channelss, heights, widths, kernel_hs, kernel_ws, pad_hs, pad_ws, stride_hs, stride_ws, dilation_hs, dilation_ws, diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp index 176fceae..c28727e7 100644 --- a/test/routines/levelx/xcol2im.hpp +++ b/test/routines/levelx/xcol2im.hpp @@ -31,7 +31,8 @@ public: // The list of arguments relevant for this routine static std::vector GetOptions() { - return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, + return {kArgKernelMode, + kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW, kArgAOffset, kArgBOffset}; } @@ -87,7 +88,8 @@ public: #ifdef OPENCL_API auto queue_plain = queue(); auto event = cl_event{}; - auto status = Col2im(args.channels, args.height, args.width, + auto status = Col2im(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -97,7 +99,8 @@ public: &queue_plain, &event); if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } #elif CUDA_API - auto status = Col2im(args.channels, args.height, args.width, + auto status = Col2im(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -167,7 +170,10 @@ StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) for (auto w_id = size_t{0}; w_id < col_w; ++w_id) { // image width // Reads the input value - const auto kernel_index = kw_id + args.kernel_w * kh_id; + const auto kernel_index + = (args.kernel_mode == KernelMode::kConvolution) + ? args.kernel_h * args.kernel_w - kw_id - args.kernel_w * kh_id - 1 + : kw_id + args.kernel_w * kh_id; const auto patch_index = w_id + col_w * h_id; const auto col_index = patch_index + kernel_index * col_w * col_h + c_id * col_w * col_h * args.kernel_h * args.kernel_w; diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp index 7fa4e701..e67b8174 100644 --- a/test/routines/levelx/xconvgemm.hpp +++ b/test/routines/levelx/xconvgemm.hpp @@ -91,7 +91,8 @@ public: #ifdef OPENCL_API auto queue_plain = queue(); auto event = cl_event{}; - auto status = Convgemm(args.channels, args.height, args.width, + auto status = Convgemm(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -103,7 +104,8 @@ public: &queue_plain, &event); if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } #elif CUDA_API - auto status = Convgemm(args.channels, args.height, args.width, + auto status = Convgemm(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -189,10 +191,16 @@ StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) const auto input_value = buffers_host.a_mat[input_index + args.a_offset]; // Multiplies with the kernel tensor - const auto kernel_index = kw_id + args.kernel_w * ( - kh_id + args.kernel_h * ( - ci_id + args.channels * ( - co_id))); + const auto kernel_index + = (args.kernel_mode == KernelMode::kConvolution) + ? (args.kernel_w - kw_id - 1) + args.kernel_w * ( + (args.kernel_h - kh_id - 1) + args.kernel_h * ( + ci_id + args.channels * ( + co_id))) + : kw_id + args.kernel_w * ( + kh_id + args.kernel_h * ( + ci_id + args.channels * ( + co_id))); const auto kernel_value = buffers_host.b_mat[kernel_index + args.b_offset]; result += input_value * kernel_value; diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp index acf7998b..2a3577c3 100644 --- a/test/routines/levelx/xim2col.hpp +++ b/test/routines/levelx/xim2col.hpp @@ -31,7 +31,8 @@ public: // The list of arguments relevant for this routine static std::vector GetOptions() { - return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, + return {kArgKernelMode, + kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW, kArgAOffset, kArgBOffset}; } @@ -87,7 +88,8 @@ public: #ifdef OPENCL_API auto queue_plain = queue(); auto event = cl_event{}; - auto status = Im2col(args.channels, args.height, args.width, + auto status = Im2col(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -97,7 +99,8 @@ public: &queue_plain, &event); if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } #elif CUDA_API - auto status = Im2col(args.channels, args.height, args.width, + auto status = Im2col(args.kernel_mode, + args.channels, args.height, args.width, args.kernel_h, args.kernel_w, args.pad_h, args.pad_w, args.stride_h, args.stride_w, @@ -175,7 +178,10 @@ StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) } // Sets the output value - const auto kernel_index = kw_id + args.kernel_w * kh_id; + const auto kernel_index + = (args.kernel_mode == KernelMode::kConvolution) + ? args.kernel_h * args.kernel_w - kw_id - args.kernel_w * kh_id - 1 + : kw_id + args.kernel_w * kh_id; const auto patch_index = w_id + col_w * h_id; const auto col_index = patch_index + kernel_index * col_w * col_h + c_id * col_w * col_h * args.kernel_h * args.kernel_w; -- cgit v1.2.3