summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- doc/api.md 51
-rw-r--r-- include/clblast.h 10
-rw-r--r-- include/clblast_c.h 40
-rw-r--r-- include/clblast_cuda.h 10
-rw-r--r-- include/clblast_netlib_c.h 25
-rwxr-xr-x scripts/generator/generator.py 8
-rw-r--r-- scripts/generator/generator/convert.py 2
-rw-r--r-- src/clblast.cpp 57
-rw-r--r-- src/clblast_c.cpp 78
-rw-r--r-- src/clblast_cuda.cpp 57
-rw-r--r-- src/clblast_netlib_c.cpp 48
-rw-r--r-- src/kernels/levelx/col2im.opencl 73
-rw-r--r-- src/kernels/levelx/im2col.opencl 59
-rw-r--r-- src/routines/levelx/xcol2im.cpp 8
-rw-r--r-- src/routines/levelx/xcol2im.hpp 3
-rw-r--r-- src/routines/levelx/xconvgemm.cpp 6
-rw-r--r-- src/routines/levelx/xconvgemm.hpp 3
-rw-r--r-- src/routines/levelx/xim2col.cpp 14
-rw-r--r-- src/routines/levelx/xim2col.hpp 3
-rw-r--r-- src/utilities/utilities.cpp 8
-rw-r--r-- src/utilities/utilities.hpp 2
-rw-r--r-- test/correctness/misc/override_parameters.cpp 1
-rw-r--r-- test/correctness/testblas.hpp 38
-rw-r--r-- test/routines/levelx/xcol2im.hpp 14
-rw-r--r-- test/routines/levelx/xconvgemm.hpp 20
-rw-r--r-- test/routines/levelx/xim2col.hpp 14
26 files changed, 461 insertions, 191 deletions
diff --git a/doc/api.md b/doc/api.md
index 337b5af9..996505f1 100644
--- a/doc/api.md
+++ b/doc/api.md
@@ -3020,7 +3020,8 @@ Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is th
C++ API:
```
template <typename T>
-StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Im2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event)
@@ -3028,23 +3029,28 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width
C API:
```
-CLBlastStatusCode CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastSim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastDim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastCim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastZim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastHim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event)
@@ -3052,6 +3058,7 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con
Arguments to IM2COL:
+* `const KernelMode kernel_mode`: The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes.
* `const size_t channels`: Integer size argument. This value must be positive.
* `const size_t height`: Integer size argument. This value must be positive.
* `const size_t width`: Integer size argument. This value must be positive.
@@ -3080,7 +3087,8 @@ Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is th
C++ API:
```
template <typename T>
-StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Col2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event)
@@ -3088,23 +3096,28 @@ StatusCode Col2im(const size_t channels, const size_t height, const size_t width
C API:
```
-CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastScol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastDcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastCcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastZcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastHcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event)
@@ -3112,6 +3125,7 @@ CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, con
Arguments to COL2IM:
+* `const KernelMode kernel_mode`: The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes.
* `const size_t channels`: Integer size argument. This value must be positive.
* `const size_t height`: Integer size argument. This value must be positive.
* `const size_t width`: Integer size argument. This value must be positive.
@@ -3140,7 +3154,8 @@ Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D i
C++ API:
```
template <typename T>
-StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+StatusCode Convgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
@@ -3149,17 +3164,20 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid
C API:
```
-CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastSconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastDconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event)
-CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastHconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
@@ -3168,6 +3186,7 @@ CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, c
Arguments to CONVGEMM:
+* `const KernelMode kernel_mode`: The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes.
* `const size_t channels`: Integer size argument. This value must be positive.
* `const size_t height`: Integer size argument. This value must be positive.
* `const size_t width`: Integer size argument. This value must be positive.
diff --git a/include/clblast.h b/include/clblast.h
index 27adf7fa..7a82361c 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -117,6 +117,7 @@ enum class Transpose { kNo = 111, kYes = 112, kConjugate = 113 };
enum class Triangle { kUpper = 121, kLower = 122 };
enum class Diagonal { kNonUnit = 131, kUnit = 132 };
enum class Side { kLeft = 141, kRight = 142 };
+enum class KernelMode { kCrossCorrelation = 151, kConvolution = 152 };
// Precision scoped enum (values in bits)
enum class Precision { kHalf = 16, kSingle = 32, kDouble = 64,
@@ -631,21 +632,24 @@ StatusCode Omatcopy(const Layout layout, const Transpose a_transpose,
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T>
-StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Im2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event = nullptr);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
template <typename T>
-StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Col2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event = nullptr);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
-StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+StatusCode Convgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
diff --git a/include/clblast_c.h b/include/clblast_c.h
index 1c681bfe..2ba6375a 100644
--- a/include/clblast_c.h
+++ b/include/clblast_c.h
@@ -120,6 +120,7 @@ typedef enum CLBlastTriangle_ { CLBlastTriangleUpper = 121,
typedef enum CLBlastDiagonal_ { CLBlastDiagonalNonUnit = 131,
CLBlastDiagonalUnit = 132 } CLBlastDiagonal;
typedef enum CLBlastSide_ { CLBlastSideLeft = 141, CLBlastSideRight = 142 } CLBlastSide;
+typedef enum CLBlastKernelMode_ { CLBlastKernelModeCrossCorrelation = 151, CLBlastKernelModeConvolution = 152 } CLBlastKernelMode;
// Precision enum (values in bits)
typedef enum CLBlastPrecision_ { CLBlastPrecisionHalf = 16, CLBlastPrecisionSingle = 32,
@@ -1389,61 +1390,74 @@ CLBlastStatusCode PUBLIC_API CLBlastHomatcopy(const CLBlastLayout layout, const
cl_command_queue* queue, cl_event* event);
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
-CLBlastStatusCode PUBLIC_API CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastSim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastDim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastCim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastZim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastHim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
-CLBlastStatusCode PUBLIC_API CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastScol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastDcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastCcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastZcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode PUBLIC_API CLBlastHcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
-CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode PUBLIC_API CLBlastDconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event);
-CLBlastStatusCode PUBLIC_API CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode PUBLIC_API CLBlastHconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h
index 58f9b74b..f6d6372d 100644
--- a/include/clblast_cuda.h
+++ b/include/clblast_cuda.h
@@ -89,6 +89,7 @@ enum class Transpose { kNo = 111, kYes = 112, kConjugate = 113 };
enum class Triangle { kUpper = 121, kLower = 122 };
enum class Diagonal { kNonUnit = 131, kUnit = 132 };
enum class Side { kLeft = 141, kRight = 142 };
+enum class KernelMode { kCrossCorrelation = 151, kConvolution = 152 };
// Precision scoped enum (values in bits)
enum class Precision { kHalf = 16, kSingle = 32, kDouble = 64,
@@ -603,21 +604,24 @@ StatusCode Omatcopy(const Layout layout, const Transpose a_transpose,
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T>
-StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Im2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const CUdeviceptr im_buffer, const size_t im_offset,
CUdeviceptr col_buffer, const size_t col_offset,
const CUcontext context, const CUdevice device);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
template <typename T>
-StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Col2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const CUdeviceptr col_buffer, const size_t col_offset,
CUdeviceptr im_buffer, const size_t im_offset,
const CUcontext context, const CUdevice device);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
-StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+StatusCode Convgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const CUdeviceptr im_buffer, const size_t im_offset,
const CUdeviceptr kernel_buffer, const size_t kernel_offset,
CUdeviceptr result_buffer, const size_t result_offset,
diff --git a/include/clblast_netlib_c.h b/include/clblast_netlib_c.h
index 65545bfb..4c54fb18 100644
--- a/include/clblast_netlib_c.h
+++ b/include/clblast_netlib_c.h
@@ -45,6 +45,7 @@ typedef enum CLBlastTriangle_ { CLBlastTriangleUpper = 121,
typedef enum CLBlastDiagonal_ { CLBlastDiagonalNonUnit = 131,
CLBlastDiagonalUnit = 132 } CLBlastDiagonal;
typedef enum CLBlastSide_ { CLBlastSideLeft = 141, CLBlastSideRight = 142 } CLBlastSide;
+typedef enum CLBlastKernelMode_ { CLBlastKernelModeCrossCorrelation = 151, CLBlastKernelModeConvolution = 152 } CLBlastKernelMode;
// For full compatibility with CBLAS
typedef CLBlastLayout CBLAS_ORDER;
@@ -947,30 +948,38 @@ void PUBLIC_API cblas_zomatcopy(const CLBlastLayout layout, const CLBlastTranspo
void* b, const int b_ld);
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
-void PUBLIC_API cblas_sim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void PUBLIC_API cblas_sim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const float* im,
float* col);
-void PUBLIC_API cblas_dim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void PUBLIC_API cblas_dim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const double* im,
double* col);
-void PUBLIC_API cblas_cim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void PUBLIC_API cblas_cim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* im,
void* col);
-void PUBLIC_API cblas_zim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void PUBLIC_API cblas_zim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* im,
void* col);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
-void PUBLIC_API cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void PUBLIC_API cblas_scol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const float* col,
float* im);
-void PUBLIC_API cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void PUBLIC_API cblas_dcol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const double* col,
double* im);
-void PUBLIC_API cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void PUBLIC_API cblas_ccol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* col,
void* im);
-void PUBLIC_API cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void PUBLIC_API cblas_zcol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* col,
void* im);
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index f8022d81..68e3f01a 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -49,7 +49,7 @@ FILES = [
"/src/clblast_cuda.cpp",
"/src/pyclblast/src/pyclblast.pyx"
]
-HEADER_LINES = [123, 21, 127, 24, 29, 45, 29, 65, 40, 95, 21, 290]
+HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 290]
FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 232
@@ -180,9 +180,9 @@ ROUTINES = [
# Special routines:
Routine(True, True, 0, False, "x", "had", T, [S,D,C,Z,H], ["n"], [], ["x","y"], ["z"], [xn,yn,zn], ["alpha","beta"], "", "Element-wise vector product (Hadamard)", "Performs the Hadamard element-wise product _z = alpha * x * y + beta * z_, in which _x_, _y_, and _z_ are vectors and _alpha_ and _beta_ are scalar constants.", []),
Routine(True, True, 0, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
- Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix. Overwrites any existing values in the _col_ buffer", []),
- Routine(True, True, 0, False, "x", "col2im", T, [S,D,C,Z,H], im2col_constants, [], ["col"], ["im"], [col,im], [""], "", "Col2im function (non-BLAS function)", "Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is the output matrix. Accumulates results on top of the existing values in the _im_ buffer.", []),
- Routine(True, True, 0, False, "x", "convgemm", T, [S,D,H], convgemm_constants, [], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []),
+ Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, ["kernel_mode"], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix. Overwrites any existing values in the _col_ buffer", []),
+ Routine(True, True, 0, False, "x", "col2im", T, [S,D,C,Z,H], im2col_constants, ["kernel_mode"], ["col"], ["im"], [col,im], [""], "", "Col2im function (non-BLAS function)", "Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is the output matrix. Accumulates results on top of the existing values in the _im_ buffer.", []),
+ Routine(True, True, 0, False, "x", "convgemm", T, [S,D,H], convgemm_constants, ["kernel_mode"], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []),
# Batched routines:
Routine(True, True, 1, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
Routine(True, True, 1, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py
index 07f45669..16890d27 100644
--- a/scripts/generator/generator/convert.py
+++ b/scripts/generator/generator/convert.py
@@ -27,6 +27,7 @@ def option_to_clblast(x):
'side': "Side",
'triangle': "Triangle",
'diagonal': "Diagonal",
+ 'kernel_mode': "KernelMode",
}[x]
@@ -79,4 +80,5 @@ def option_to_documentation(x):
'side': "The position of the triangular matrix in the operation, either on the `Side::kLeft` (141) or `Side::kRight` (142).",
'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.",
+ 'kernel_mode': "The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes.",
}[x]
diff --git a/src/clblast.cpp b/src/clblast.cpp
index e45f504a..180693e7 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2218,79 +2218,94 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T>
-StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Im2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = Xim2col<T>(queue_cpp, event);
- routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoIm2col(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(col_buffer), col_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
template <typename T>
-StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Col2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = Xcol2im<T>(queue_cpp, event);
- routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoCol2im(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(col_buffer), col_offset,
Buffer<T>(im_buffer), im_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
-StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+StatusCode Convgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
@@ -2298,24 +2313,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid
try {
auto queue_cpp = Queue(*queue);
auto routine = Xconvgemm<T>(queue_cpp, event);
- routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ routine.DoConvgemm(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(kernel_buffer), kernel_offset,
Buffer<T>(result_buffer), result_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 645a69b1..a224230a 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -3613,65 +3613,75 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran
}
// IM2COL
-CLBlastStatusCode CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastSim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastDim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastCim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastZim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastHim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
@@ -3680,65 +3690,75 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con
}
// COL2IM
-CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastScol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastDcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastCcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastZcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastHcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
@@ -3747,14 +3767,16 @@ CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, con
}
// CONVGEMM
-CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastSconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
@@ -3762,14 +3784,16 @@ CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, c
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastDconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
@@ -3777,14 +3801,16 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastHconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 03d995ba..264f360d 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2314,7 +2314,8 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T>
-StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Im2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const CUdeviceptr im_buffer, const size_t im_offset,
CUdeviceptr col_buffer, const size_t col_offset,
const CUcontext context, const CUdevice device) {
@@ -2323,36 +2324,43 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xim2col<T>(queue_cpp, nullptr);
- routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoIm2col(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(col_buffer), col_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
template <typename T>
-StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Col2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const CUdeviceptr col_buffer, const size_t col_offset,
CUdeviceptr im_buffer, const size_t im_offset,
const CUcontext context, const CUdevice device) {
@@ -2361,36 +2369,43 @@ StatusCode Col2im(const size_t channels, const size_t height, const size_t width
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xcol2im<T>(queue_cpp, nullptr);
- routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoCol2im(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(col_buffer), col_offset,
Buffer<T>(im_buffer), im_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
-StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+StatusCode Convgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const CUdeviceptr im_buffer, const size_t im_offset,
const CUdeviceptr kernel_buffer, const size_t kernel_offset,
CUdeviceptr result_buffer, const size_t result_offset,
@@ -2400,24 +2415,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xconvgemm<T>(queue_cpp, nullptr);
- routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ routine.DoConvgemm(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(kernel_buffer), kernel_offset,
Buffer<T>(result_buffer), result_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
diff --git a/src/clblast_netlib_c.cpp b/src/clblast_netlib_c.cpp
index 22570535..3a8f729e 100644
--- a/src/clblast_netlib_c.cpp
+++ b/src/clblast_netlib_c.cpp
@@ -4878,7 +4878,8 @@ void cblas_zomatcopy(const CLBlastLayout layout, const CLBlastTranspose a_transp
}
// IM2COL
-void cblas_sim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_sim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const float* im,
float* col) {
OPTIONAL_STATIC auto device = get_device();
@@ -4891,7 +4892,8 @@ void cblas_sim2col(const int channels, const int height, const int width, const
im_buffer.Write(queue, im_size, reinterpret_cast<const float*>(im));
col_buffer.Write(queue, col_size, reinterpret_cast<float*>(col));
auto queue_cl = queue();
- auto s = clblast::Im2col<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Im2col<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer(), 0,
col_buffer(), 0,
&queue_cl);
@@ -4900,7 +4902,8 @@ void cblas_sim2col(const int channels, const int height, const int width, const
}
col_buffer.Read(queue, col_size, reinterpret_cast<float*>(col));
}
-void cblas_dim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_dim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const double* im,
double* col) {
OPTIONAL_STATIC auto device = get_device();
@@ -4913,7 +4916,8 @@ void cblas_dim2col(const int channels, const int height, const int width, const
im_buffer.Write(queue, im_size, reinterpret_cast<const double*>(im));
col_buffer.Write(queue, col_size, reinterpret_cast<double*>(col));
auto queue_cl = queue();
- auto s = clblast::Im2col<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Im2col<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer(), 0,
col_buffer(), 0,
&queue_cl);
@@ -4922,7 +4926,8 @@ void cblas_dim2col(const int channels, const int height, const int width, const
}
col_buffer.Read(queue, col_size, reinterpret_cast<double*>(col));
}
-void cblas_cim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_cim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* im,
void* col) {
OPTIONAL_STATIC auto device = get_device();
@@ -4935,7 +4940,8 @@ void cblas_cim2col(const int channels, const int height, const int width, const
im_buffer.Write(queue, im_size, reinterpret_cast<const float2*>(im));
col_buffer.Write(queue, col_size, reinterpret_cast<float2*>(col));
auto queue_cl = queue();
- auto s = clblast::Im2col<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Im2col<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer(), 0,
col_buffer(), 0,
&queue_cl);
@@ -4944,7 +4950,8 @@ void cblas_cim2col(const int channels, const int height, const int width, const
}
col_buffer.Read(queue, col_size, reinterpret_cast<float2*>(col));
}
-void cblas_zim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_zim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* im,
void* col) {
OPTIONAL_STATIC auto device = get_device();
@@ -4957,7 +4964,8 @@ void cblas_zim2col(const int channels, const int height, const int width, const
im_buffer.Write(queue, im_size, reinterpret_cast<const double2*>(im));
col_buffer.Write(queue, col_size, reinterpret_cast<double2*>(col));
auto queue_cl = queue();
- auto s = clblast::Im2col<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Im2col<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer(), 0,
col_buffer(), 0,
&queue_cl);
@@ -4968,7 +4976,8 @@ void cblas_zim2col(const int channels, const int height, const int width, const
}
// COL2IM
-void cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_scol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const float* col,
float* im) {
OPTIONAL_STATIC auto device = get_device();
@@ -4981,7 +4990,8 @@ void cblas_scol2im(const int channels, const int height, const int width, const
col_buffer.Write(queue, col_size, reinterpret_cast<const float*>(col));
im_buffer.Write(queue, im_size, reinterpret_cast<float*>(im));
auto queue_cl = queue();
- auto s = clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Col2im<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer(), 0,
im_buffer(), 0,
&queue_cl);
@@ -4990,7 +5000,8 @@ void cblas_scol2im(const int channels, const int height, const int width, const
}
im_buffer.Read(queue, im_size, reinterpret_cast<float*>(im));
}
-void cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_dcol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const double* col,
double* im) {
OPTIONAL_STATIC auto device = get_device();
@@ -5003,7 +5014,8 @@ void cblas_dcol2im(const int channels, const int height, const int width, const
col_buffer.Write(queue, col_size, reinterpret_cast<const double*>(col));
im_buffer.Write(queue, im_size, reinterpret_cast<double*>(im));
auto queue_cl = queue();
- auto s = clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Col2im<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer(), 0,
im_buffer(), 0,
&queue_cl);
@@ -5012,7 +5024,8 @@ void cblas_dcol2im(const int channels, const int height, const int width, const
}
im_buffer.Read(queue, im_size, reinterpret_cast<double*>(im));
}
-void cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_ccol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* col,
void* im) {
OPTIONAL_STATIC auto device = get_device();
@@ -5025,7 +5038,8 @@ void cblas_ccol2im(const int channels, const int height, const int width, const
col_buffer.Write(queue, col_size, reinterpret_cast<const float2*>(col));
im_buffer.Write(queue, im_size, reinterpret_cast<float2*>(im));
auto queue_cl = queue();
- auto s = clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Col2im<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer(), 0,
im_buffer(), 0,
&queue_cl);
@@ -5034,7 +5048,8 @@ void cblas_ccol2im(const int channels, const int height, const int width, const
}
im_buffer.Read(queue, im_size, reinterpret_cast<float2*>(im));
}
-void cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_zcol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* col,
void* im) {
OPTIONAL_STATIC auto device = get_device();
@@ -5047,7 +5062,8 @@ void cblas_zcol2im(const int channels, const int height, const int width, const
col_buffer.Write(queue, col_size, reinterpret_cast<const double2*>(col));
im_buffer.Write(queue, im_size, reinterpret_cast<double2*>(im));
auto queue_cl = queue();
- auto s = clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Col2im<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer(), 0,
im_buffer(), 0,
&queue_cl);
diff --git a/src/kernels/levelx/col2im.opencl b/src/kernels/levelx/col2im.opencl
index a37db24f..484a7a98 100644
--- a/src/kernels/levelx/col2im.opencl
+++ b/src/kernels/levelx/col2im.opencl
@@ -28,18 +28,20 @@ inline int grid_ceil(const int x, const int step) {
return x > 0 ? ((x - 1) / step + 1) * step : x / step * step;
}
+// Main body of the kernel
__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
-void col2im(const int input_h, const int input_w, const int channels,
- const int output_h, const int output_w,
- const int kernel_h, const int kernel_w,
- const int pad_h, const int pad_w,
- const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w,
- const int stride_bez_h, const int stride_bez_w,
- const int dilation_bez_h, const int dilation_bez_w,
- const int gcd_h, const int gcd_w,
- const __global real* restrict col_buffer, const int col_offset,
- __global real* im_buffer, const int im_offset) {
+INLINE_FUNC void Xcol2im(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int stride_bez_h, const int stride_bez_w,
+ const int dilation_bez_h, const int dilation_bez_w,
+ const int gcd_h, const int gcd_w,
+ const bool kernel_flip,
+ const __global real* restrict col_buffer, const int col_offset,
+ __global real* im_buffer, const int im_offset) {
const int input_h_scaled = (input_h - 1) / gcd_h + 1;
@@ -71,8 +73,9 @@ void col2im(const int input_h, const int input_w, const int channels,
const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w;
const int h_id = th / stride_h + stride_bez_h * gcd_scale_h;
const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w;
-
- const int kernel_index = kw_id + kernel_w * kh_id;
+ const int kernel_index = (kernel_flip)
+ ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1
+ : kw_id + kernel_w * kh_id;
const int patch_index = w_id + output_w * h_id;
const int output_index = patch_index + kernel_index * output_w * output_h +
c_id * output_w * output_h * kernel_h * kernel_w;
@@ -89,6 +92,50 @@ void col2im(const int input_h, const int input_w, const int channels,
// =================================================================================================
+// Convolution entry point for col2im: forwards every argument unchanged to the shared Xcol2im
+// body with kernel flipping enabled, which mirrors the kernel index to
+// kernel_h * kernel_w - 1 - (kw_id + kernel_w * kh_id).
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xcol2imKernelFlip(const int input_h, const int input_w, const int channels,
+                       const int output_h, const int output_w,
+                       const int kernel_h, const int kernel_w,
+                       const int pad_h, const int pad_w,
+                       const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int stride_bez_h, const int stride_bez_w,
+                       const int dilation_bez_h, const int dilation_bez_w,
+                       const int gcd_h, const int gcd_w,
+                       const __global real* restrict col_buffer, const int col_offset,
+                       __global real* im_buffer, const int im_offset) {
+  Xcol2im(input_h, input_w, channels,
+          output_h, output_w,
+          kernel_h, kernel_w,
+          pad_h, pad_w,
+          stride_h, stride_w,
+          dilation_h, dilation_w,
+          stride_bez_h, stride_bez_w,
+          dilation_bez_h, dilation_bez_w,
+          gcd_h, gcd_w,
+          true,  // kernel_flip: mirror the kernel indices
+          col_buffer, col_offset,
+          im_buffer, im_offset);
+}
+
+// Cross-correlation entry point for col2im: forwards every argument unchanged to the shared
+// Xcol2im body with kernel flipping disabled, so the kernel index stays kw_id + kernel_w * kh_id.
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xcol2imKernelNormal(const int input_h, const int input_w, const int channels,
+                         const int output_h, const int output_w,
+                         const int kernel_h, const int kernel_w,
+                         const int pad_h, const int pad_w,
+                         const int stride_h, const int stride_w,
+                         const int dilation_h, const int dilation_w,
+                         const int stride_bez_h, const int stride_bez_w,
+                         const int dilation_bez_h, const int dilation_bez_w,
+                         const int gcd_h, const int gcd_w,
+                         const __global real* restrict col_buffer, const int col_offset,
+                         __global real* im_buffer, const int im_offset) {
+  Xcol2im(input_h, input_w, channels,
+          output_h, output_w,
+          kernel_h, kernel_w,
+          pad_h, pad_w,
+          stride_h, stride_w,
+          dilation_h, dilation_w,
+          stride_bez_h, stride_bez_w,
+          dilation_bez_h, dilation_bez_w,
+          gcd_h, gcd_w,
+          false,  // kernel_flip: keep the kernel indices as they are
+          col_buffer, col_offset,
+          im_buffer, im_offset);
+}
+
+// =================================================================================================
+
// End of the C++11 raw string literal
)"
diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl
index 301e076b..5db4cb5f 100644
--- a/src/kernels/levelx/im2col.opencl
+++ b/src/kernels/levelx/im2col.opencl
@@ -25,15 +25,16 @@ R"(
// =================================================================================================
-__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
-void im2col(const int input_h, const int input_w, const int channels,
- const int output_h, const int output_w,
- const int kernel_h, const int kernel_w,
- const int pad_h, const int pad_w,
- const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w,
- const __global real* restrict im_buffer, const int im_offset,
- __global real* col_buffer, const int col_offset) {
+// Main body of the kernel
+INLINE_FUNC void Xim2col(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const bool kernel_flip,
+ const __global real* restrict im_buffer, const int im_offset,
+ __global real* col_buffer, const int col_offset) {
// Thread IDs
const int w_id = get_global_id(0); // image width, max 'output_w'
@@ -58,7 +59,9 @@ void im2col(const int input_h, const int input_w, const int channels,
}
// Sets the output value
- const int kernel_index = kw_id + kernel_w * kh_id;
+ const int kernel_index = (kernel_flip)
+ ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1
+ : kw_id + kernel_w * kh_id;
const int patch_index = w_id + output_w * h_id;
const int output_index = patch_index + kernel_index * output_w * output_h +
c_id * output_w * output_h * kernel_h * kernel_w;
@@ -70,6 +73,42 @@ void im2col(const int input_h, const int input_w, const int channels,
// =================================================================================================
+// Convolution entry point for im2col: forwards every argument unchanged to the shared Xim2col
+// body with kernel flipping enabled, which mirrors the kernel index to
+// kernel_h * kernel_w - 1 - (kw_id + kernel_w * kh_id).
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xim2colKernelFlip(const int input_h, const int input_w, const int channels,
+                       const int output_h, const int output_w,
+                       const int kernel_h, const int kernel_w,
+                       const int pad_h, const int pad_w,
+                       const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const __global real* restrict im_buffer, const int im_offset,
+                       __global real* col_buffer, const int col_offset) {
+  Xim2col(input_h, input_w, channels,
+          output_h, output_w,
+          kernel_h, kernel_w,
+          pad_h, pad_w,
+          stride_h, stride_w,
+          dilation_h, dilation_w,
+          true,  // kernel_flip: mirror the kernel indices
+          im_buffer, im_offset,
+          col_buffer, col_offset);
+}
+
+// Cross-correlation entry point for im2col: forwards every argument unchanged to the shared
+// Xim2col body with kernel flipping disabled, so the kernel index stays kw_id + kernel_w * kh_id.
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xim2colKernelNormal(const int input_h, const int input_w, const int channels,
+                         const int output_h, const int output_w,
+                         const int kernel_h, const int kernel_w,
+                         const int pad_h, const int pad_w,
+                         const int stride_h, const int stride_w,
+                         const int dilation_h, const int dilation_w,
+                         const __global real* restrict im_buffer, const int im_offset,
+                         __global real* col_buffer, const int col_offset) {
+  Xim2col(input_h, input_w, channels,
+          output_h, output_w,
+          kernel_h, kernel_w,
+          pad_h, pad_w,
+          stride_h, stride_w,
+          dilation_h, dilation_w,
+          false,  // kernel_flip: keep the kernel indices as they are
+          im_buffer, im_offset,
+          col_buffer, col_offset);
+}
+
+// =================================================================================================
+
// End of the C++11 raw string literal
)"
diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp
index 7a0c36b7..d285e5c0 100644
--- a/src/routines/levelx/xcol2im.cpp
+++ b/src/routines/levelx/xcol2im.cpp
@@ -31,13 +31,17 @@ Xcol2im<T>::Xcol2im(Queue &queue, EventPointer event, const std::string &name):
// The main routine
template <typename T>
-void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size_t width,
+void Xcol2im<T>::DoCol2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
const size_t pad_w, const size_t stride_h, const size_t stride_w,
const size_t dilation_h, const size_t dilation_w,
const Buffer<T> &col_buffer, const size_t col_offset,
const Buffer<T> &im_buffer, const size_t im_offset) {
+ // Flip the output along kernel_h and kernel_w, or not.
+ const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xcol2imKernelFlip" : "Xcol2imKernelNormal";
+
// Makes sure all dimensions are larger than zero
if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
@@ -59,7 +63,7 @@ void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size
EuclidGCD(static_cast<int>(stride_w), static_cast<int>(dilation_w), stride_bez_w, dilation_bez_w, gcd_w);
// Retrieves the kernel from the compiled binary
- auto kernel = Kernel(program_, "col2im");
+ auto kernel = Kernel(program_, kernel_name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(height));
diff --git a/src/routines/levelx/xcol2im.hpp b/src/routines/levelx/xcol2im.hpp
index 86d68c45..522c717e 100644
--- a/src/routines/levelx/xcol2im.hpp
+++ b/src/routines/levelx/xcol2im.hpp
@@ -29,7 +29,8 @@ class Xcol2im: public Routine {
Xcol2im(Queue &queue, EventPointer event, const std::string &name = "COL2IM");
// Templated-precision implementation of the routine
- void DoCol2im(const size_t channels, const size_t height, const size_t width,
+ void DoCol2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w,
const size_t pad_h, const size_t pad_w,
const size_t stride_h, const size_t stride_w,
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index f26f23a7..88127b0f 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -43,7 +43,8 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam
// =================================================================================================
template <typename T>
-void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const size_t width,
+void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
const size_t pad_w, const size_t stride_h, const size_t stride_w,
const size_t dilation_h, const size_t dilation_w,
@@ -94,7 +95,8 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
const auto col_batch_offset = batch_id * patch_size * num_patches;
auto im2col_event = Event();
auto im2col = Xim2col<T>(queue_, im2col_event.pointer());
- im2col.DoIm2col(channels, height, width, kernel_h, kernel_w,
+ im2col.DoIm2col(kernel_mode,
+ channels, height, width, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_batch_offset,
col_buffer, col_batch_offset);
diff --git a/src/routines/levelx/xconvgemm.hpp b/src/routines/levelx/xconvgemm.hpp
index 9d11ccee..20cfff60 100644
--- a/src/routines/levelx/xconvgemm.hpp
+++ b/src/routines/levelx/xconvgemm.hpp
@@ -32,7 +32,8 @@ class Xconvgemm: public Routine {
const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col);
// Templated-precision implementation of the routine
- void DoConvgemm(const size_t channels, const size_t height, const size_t width,
+ void DoConvgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w,
const size_t pad_h, const size_t pad_w,
const size_t stride_h, const size_t stride_w,
diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp
index 09dcc42c..0f786974 100644
--- a/src/routines/levelx/xim2col.cpp
+++ b/src/routines/levelx/xim2col.cpp
@@ -22,22 +22,26 @@ namespace clblast {
// Constructor: forwards to base class constructor
template <typename T>
Xim2col<T>::Xim2col(Queue &queue, EventPointer event, const std::string &name):
- Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, {
-#include "../../kernels/levelx/im2col.opencl"
- }) {
+ Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, {
+ #include "../../kernels/levelx/im2col.opencl"
+ }) {
}
// =================================================================================================
// The main routine
template <typename T>
-void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size_t width,
+void Xim2col<T>::DoIm2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
const size_t pad_w, const size_t stride_h, const size_t stride_w,
const size_t dilation_h, const size_t dilation_w,
const Buffer<T> &im_buffer, const size_t im_offset,
const Buffer<T> &col_buffer, const size_t col_offset) {
+ // Flip the output along kernel_h and kernel_w, or not.
+ const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xim2colKernelFlip" : "Xim2colKernelNormal";
+
// Makes sure all dimensions are larger than zero
if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
@@ -50,7 +54,7 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size
const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
// Retrieves the kernel from the compiled binary
- auto kernel = Kernel(program_, "im2col");
+ auto kernel = Kernel(program_, kernel_name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(height));
diff --git a/src/routines/levelx/xim2col.hpp b/src/routines/levelx/xim2col.hpp
index 2c03b169..77cc32eb 100644
--- a/src/routines/levelx/xim2col.hpp
+++ b/src/routines/levelx/xim2col.hpp
@@ -29,7 +29,8 @@ class Xim2col: public Routine {
Xim2col(Queue &queue, EventPointer event, const std::string &name = "IM2COL");
// Templated-precision implementation of the routine
- void DoIm2col(const size_t channels, const size_t height, const size_t width,
+ void DoIm2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w,
const size_t pad_h, const size_t pad_w,
const size_t stride_h, const size_t stride_w,
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp
index a6cd82e7..a0e89c98 100644
--- a/src/utilities/utilities.cpp
+++ b/src/utilities/utilities.cpp
@@ -175,6 +175,13 @@ std::string ToString(Precision value) {
}
}
+// Converts a KernelMode into a string holding its integer value followed by a human-readable
+// name in parentheses. Values that match neither enumerator (possible after an unchecked cast
+// from an integer or from the C API enum) are reported as "(unknown)" rather than running off
+// the end of the function, which would be undefined behaviour.
template <>
+std::string ToString(KernelMode value) {
+  switch(value) {
+    case KernelMode::kCrossCorrelation: return ToString(static_cast<int>(value))+" (cross-correlation)";
+    case KernelMode::kConvolution: return ToString(static_cast<int>(value))+" (convolution)";
+  }
+  // No 'default' label above, so compilers can still warn when a new enumerator is not handled
+  return ToString(static_cast<int>(value))+" (unknown)";
+}
+template <>
std::string ToString(StatusCode value) {
return std::to_string(static_cast<int>(value));
}
@@ -281,6 +288,7 @@ template Side GetArgument<Side>(const std::vector<std::string>&, std::string&, c
template Triangle GetArgument<Triangle>(const std::vector<std::string>&, std::string&, const std::string&, const Triangle);
template Diagonal GetArgument<Diagonal>(const std::vector<std::string>&, std::string&, const std::string&, const Diagonal);
template Precision GetArgument<Precision>(const std::vector<std::string>&, std::string&, const std::string&, const Precision);
+template KernelMode GetArgument<KernelMode>(const std::vector<std::string>&, std::string&, const std::string&, const KernelMode);
// =================================================================================================
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index fcc1c57f..23486d35 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -69,6 +69,7 @@ constexpr auto kArgBTransp = "transB";
constexpr auto kArgSide = "side";
constexpr auto kArgTriangle = "triangle";
constexpr auto kArgDiagonal = "diagonal";
+constexpr auto kArgKernelMode = "kernel_mode";
constexpr auto kArgXInc = "incx";
constexpr auto kArgYInc = "incy";
constexpr auto kArgXOffset = "offx";
@@ -183,6 +184,7 @@ struct Arguments {
Side side = Side::kLeft;
Triangle triangle = Triangle::kUpper;
Diagonal diagonal = Diagonal::kUnit;
+ KernelMode kernel_mode = KernelMode::kCrossCorrelation;
size_t x_inc = 1;
size_t y_inc = 1;
size_t x_offset = 0;
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp
index 54229c5e..7ed4faff 100644
--- a/test/correctness/misc/override_parameters.cpp
+++ b/test/correctness/misc/override_parameters.cpp
@@ -60,6 +60,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st
args.layout = GetArgument(arguments, help, kArgLayout, Layout::kRowMajor);
args.a_transpose = GetArgument(arguments, help, kArgATransp, Transpose::kNo);
args.b_transpose = GetArgument(arguments, help, kArgBTransp, Transpose::kNo);
+ args.kernel_mode = GetArgument(arguments, help, kArgKernelMode, KernelMode::kCrossCorrelation);
args.alpha = GetArgument(arguments, help, kArgAlpha, GetScalar<T>());
args.beta = GetArgument(arguments, help, kArgBeta, GetScalar<T>());
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp
index 137df30f..b2dc6e7a 100644
--- a/test/correctness/testblas.hpp
+++ b/test/correctness/testblas.hpp
@@ -63,6 +63,7 @@ class TestBlas: public Tester<T,U> {
static const std::vector<size_t> kNumKernels;
static const std::vector<size_t> kStrideValues;
static const std::vector<size_t> kChannelValues;
+ static const std::vector<KernelMode> kKernelModes;
const std::vector<size_t> kOffsets;
const std::vector<U> kAlphaValues;
const std::vector<U> kBetaValues;
@@ -142,6 +143,7 @@ template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kKern
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kNumKernels = { 1, 6 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kStrideValues = { 1, 3 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kChannelValues = { 1, 2 };
+template <typename T, typename U> const std::vector<KernelMode> TestBlas<T,U>::kKernelModes = { KernelMode::kCrossCorrelation, KernelMode::kConvolution };
// Test settings for the invalid tests
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kInvalidIncrements = { 0, 1 };
@@ -168,6 +170,7 @@ static StatusCode ReferenceNotAvailable(const Arguments<U> &, BufferType &, Queu
template <typename C, typename T, typename U>
void handle_remaining_of_options(std::vector<Arguments<U>> &regular_test_vector, Arguments<U> &r_args,
TestBlas<T,U> &tester,
+ const std::vector<KernelMode> &kernel_modes,
const std::vector<size_t> &channelss,
const std::vector<size_t> &heights,
const std::vector<size_t> &widths,
@@ -181,21 +184,23 @@ void handle_remaining_of_options(std::vector<Arguments<U>> &regular_test_vector,
const std::vector<size_t> &dilation_ws,
const std::vector<size_t> &batch_counts,
const std::vector<size_t> &num_kernelss) {
- for (auto &channels: channelss) { r_args.channels = channels;
- for (auto &height: heights) { r_args.height = height;
- for (auto &width: widths) { r_args.width = width;
- for (auto &kernel_h: kernel_hs) { r_args.kernel_h = kernel_h;
- for (auto &kernel_w: kernel_ws) { r_args.kernel_w = kernel_w;
- for (auto &pad_h: pad_hs) { r_args.pad_h = pad_h;
- for (auto &pad_w: pad_ws) { r_args.pad_w = pad_w;
- for (auto &stride_h: stride_hs) { r_args.stride_h = stride_h;
- for (auto &stride_w: stride_ws) { r_args.stride_w = stride_w;
- for (auto &dilation_h: dilation_hs) { r_args.dilation_h = dilation_h;
- for (auto &dilation_w: dilation_ws) { r_args.dilation_w = dilation_w;
- for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count;
- for (auto &num_kernels: num_kernelss) { r_args.num_kernels = num_kernels;
- C::SetSizes(r_args, tester.queue_);
- regular_test_vector.push_back(r_args);
+ for (auto &kernel_mode: kernel_modes) { r_args.kernel_mode = kernel_mode;
+ for (auto &channels: channelss) { r_args.channels = channels;
+ for (auto &height: heights) { r_args.height = height;
+ for (auto &width: widths) { r_args.width = width;
+ for (auto &kernel_h: kernel_hs) { r_args.kernel_h = kernel_h;
+ for (auto &kernel_w: kernel_ws) { r_args.kernel_w = kernel_w;
+ for (auto &pad_h: pad_hs) { r_args.pad_h = pad_h;
+ for (auto &pad_w: pad_ws) { r_args.pad_w = pad_w;
+ for (auto &stride_h: stride_hs) { r_args.stride_h = stride_h;
+ for (auto &stride_w: stride_ws) { r_args.stride_w = stride_w;
+ for (auto &dilation_h: dilation_hs) { r_args.dilation_h = dilation_h;
+ for (auto &dilation_w: dilation_ws) { r_args.dilation_w = dilation_w;
+ for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count;
+ for (auto &num_kernels: num_kernelss) { r_args.num_kernels = num_kernels;
+ C::SetSizes(r_args, tester.queue_);
+ regular_test_vector.push_back(r_args);
+ }
}
}
}
@@ -284,6 +289,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
auto imax_offsets = std::vector<size_t>{args.imax_offset};
auto alphas = std::vector<U>{args.alpha};
auto betas = std::vector<U>{args.beta};
+ auto kernel_modes = std::vector<KernelMode>{args.kernel_mode};
auto channelss = std::vector<size_t>{args.channels};
auto heights = std::vector<size_t>{args.height};
auto widths = std::vector<size_t>{args.width};
@@ -340,6 +346,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; }
if (option == kArgAlpha) { alphas = tester.kAlphaValues; }
if (option == kArgBeta) { betas = tester.kBetaValues; }
+ if (option == kArgKernelMode) { kernel_modes = tester.kKernelModes; }
if (option == kArgChannels) { channelss = tester.kChannelValues; }
if (option == kArgHeight) { heights = tester.kMatrixDims; }
if (option == kArgWidth) { widths = tester.kMatrixDims; }
@@ -397,6 +404,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
for (auto &beta: betas) { r_args.beta = beta;
// Cannot have more for-loops because of MSVC's C1061 error
handle_remaining_of_options<C>(regular_test_vector, r_args, tester,
+ kernel_modes,
channelss, heights, widths, kernel_hs, kernel_ws,
pad_hs, pad_ws, stride_hs, stride_ws,
dilation_hs, dilation_ws,
diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp
index 176fceae..c28727e7 100644
--- a/test/routines/levelx/xcol2im.hpp
+++ b/test/routines/levelx/xcol2im.hpp
@@ -31,7 +31,8 @@ public:
// The list of arguments relevant for this routine
static std::vector<std::string> GetOptions() {
- return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW,
+ return {kArgKernelMode,
+ kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW,
kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW,
kArgAOffset, kArgBOffset};
}
@@ -87,7 +88,8 @@ public:
#ifdef OPENCL_API
auto queue_plain = queue();
auto event = cl_event{};
- auto status = Col2im<T>(args.channels, args.height, args.width,
+ auto status = Col2im<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -97,7 +99,8 @@ public:
&queue_plain, &event);
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
#elif CUDA_API
- auto status = Col2im<T>(args.channels, args.height, args.width,
+ auto status = Col2im<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -167,7 +170,10 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
for (auto w_id = size_t{0}; w_id < col_w; ++w_id) { // image width
// Reads the input value
- const auto kernel_index = kw_id + args.kernel_w * kh_id;
+ const auto kernel_index
+ = (args.kernel_mode == KernelMode::kConvolution)
+ ? args.kernel_h * args.kernel_w - kw_id - args.kernel_w * kh_id - 1
+ : kw_id + args.kernel_w * kh_id;
const auto patch_index = w_id + col_w * h_id;
const auto col_index = patch_index + kernel_index * col_w * col_h +
c_id * col_w * col_h * args.kernel_h * args.kernel_w;
diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp
index 7fa4e701..e67b8174 100644
--- a/test/routines/levelx/xconvgemm.hpp
+++ b/test/routines/levelx/xconvgemm.hpp
@@ -91,7 +91,8 @@ public:
#ifdef OPENCL_API
auto queue_plain = queue();
auto event = cl_event{};
- auto status = Convgemm<T>(args.channels, args.height, args.width,
+ auto status = Convgemm<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -103,7 +104,8 @@ public:
&queue_plain, &event);
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
#elif CUDA_API
- auto status = Convgemm<T>(args.channels, args.height, args.width,
+ auto status = Convgemm<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -189,10 +191,16 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
const auto input_value = buffers_host.a_mat[input_index + args.a_offset];
// Multiplies with the kernel tensor
- const auto kernel_index = kw_id + args.kernel_w * (
- kh_id + args.kernel_h * (
- ci_id + args.channels * (
- co_id)));
+ const auto kernel_index
+ = (args.kernel_mode == KernelMode::kConvolution)
+ ? (args.kernel_w - kw_id - 1) + args.kernel_w * (
+ (args.kernel_h - kh_id - 1) + args.kernel_h * (
+ ci_id + args.channels * (
+ co_id)))
+ : kw_id + args.kernel_w * (
+ kh_id + args.kernel_h * (
+ ci_id + args.channels * (
+ co_id)));
const auto kernel_value = buffers_host.b_mat[kernel_index + args.b_offset];
result += input_value * kernel_value;
diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp
index acf7998b..2a3577c3 100644
--- a/test/routines/levelx/xim2col.hpp
+++ b/test/routines/levelx/xim2col.hpp
@@ -31,7 +31,8 @@ public:
// The list of arguments relevant for this routine
static std::vector<std::string> GetOptions() {
- return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW,
+ return {kArgKernelMode,
+ kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW,
kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW,
kArgAOffset, kArgBOffset};
}
@@ -87,7 +88,8 @@ public:
#ifdef OPENCL_API
auto queue_plain = queue();
auto event = cl_event{};
- auto status = Im2col<T>(args.channels, args.height, args.width,
+ auto status = Im2col<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -97,7 +99,8 @@ public:
&queue_plain, &event);
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
#elif CUDA_API
- auto status = Im2col<T>(args.channels, args.height, args.width,
+ auto status = Im2col<T>(args.kernel_mode,
+ args.channels, args.height, args.width,
args.kernel_h, args.kernel_w,
args.pad_h, args.pad_w,
args.stride_h, args.stride_w,
@@ -175,7 +178,10 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
}
// Sets the output value
- const auto kernel_index = kw_id + args.kernel_w * kh_id;
+ const auto kernel_index
+ = (args.kernel_mode == KernelMode::kConvolution)
+ ? args.kernel_h * args.kernel_w - kw_id - args.kernel_w * kh_id - 1
+ : kw_id + args.kernel_w * kh_id;
const auto patch_index = w_id + col_w * h_id;
const auto col_index = patch_index + kernel_index * col_w * col_h +
c_id * col_w * col_h * args.kernel_h * args.kernel_w;