summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/clblast.cpp 57
-rw-r--r-- src/clblast_c.cpp 78
-rw-r--r-- src/clblast_cuda.cpp 57
-rw-r--r-- src/clblast_netlib_c.cpp 48
-rw-r--r-- src/kernels/levelx/col2im.opencl 73
-rw-r--r-- src/kernels/levelx/im2col.opencl 59
-rw-r--r-- src/routines/levelx/xcol2im.cpp 8
-rw-r--r-- src/routines/levelx/xcol2im.hpp 3
-rw-r--r-- src/routines/levelx/xconvgemm.cpp 6
-rw-r--r-- src/routines/levelx/xconvgemm.hpp 3
-rw-r--r-- src/routines/levelx/xim2col.cpp 14
-rw-r--r-- src/routines/levelx/xim2col.hpp 3
-rw-r--r-- src/utilities/utilities.cpp 8
-rw-r--r-- src/utilities/utilities.hpp 2
14 files changed, 304 insertions, 115 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index e45f504a..180693e7 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2218,79 +2218,94 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T>
-StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Im2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = Xim2col<T>(queue_cpp, event);
- routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoIm2col(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(col_buffer), col_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
template <typename T>
-StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Col2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = Xcol2im<T>(queue_cpp, event);
- routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoCol2im(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(col_buffer), col_offset,
Buffer<T>(im_buffer), im_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
-StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+StatusCode Convgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
@@ -2298,24 +2313,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid
try {
auto queue_cpp = Queue(*queue);
auto routine = Xconvgemm<T>(queue_cpp, event);
- routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ routine.DoConvgemm(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(kernel_buffer), kernel_offset,
Buffer<T>(result_buffer), result_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 645a69b1..a224230a 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -3613,65 +3613,75 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran
}
// IM2COL
-CLBlastStatusCode CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastSim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastDim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastCim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastZim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastHim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
@@ -3680,65 +3690,75 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con
}
// COL2IM
-CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastScol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastDcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastCcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastZcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastHcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
@@ -3747,14 +3767,16 @@ CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, con
}
// CONVGEMM
-CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastSconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
@@ -3762,14 +3784,16 @@ CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, c
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastDconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
@@ -3777,14 +3801,16 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastHconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 03d995ba..264f360d 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2314,7 +2314,8 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T>
-StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Im2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const CUdeviceptr im_buffer, const size_t im_offset,
CUdeviceptr col_buffer, const size_t col_offset,
const CUcontext context, const CUdevice device) {
@@ -2323,36 +2324,43 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xim2col<T>(queue_cpp, nullptr);
- routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoIm2col(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(col_buffer), col_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
template <typename T>
-StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Col2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const CUdeviceptr col_buffer, const size_t col_offset,
CUdeviceptr im_buffer, const size_t im_offset,
const CUcontext context, const CUdevice device) {
@@ -2361,36 +2369,43 @@ StatusCode Col2im(const size_t channels, const size_t height, const size_t width
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xcol2im<T>(queue_cpp, nullptr);
- routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoCol2im(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(col_buffer), col_offset,
Buffer<T>(im_buffer), im_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
-StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+StatusCode Convgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const CUdeviceptr im_buffer, const size_t im_offset,
const CUdeviceptr kernel_buffer, const size_t kernel_offset,
CUdeviceptr result_buffer, const size_t result_offset,
@@ -2400,24 +2415,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xconvgemm<T>(queue_cpp, nullptr);
- routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ routine.DoConvgemm(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(kernel_buffer), kernel_offset,
Buffer<T>(result_buffer), result_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
-template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
diff --git a/src/clblast_netlib_c.cpp b/src/clblast_netlib_c.cpp
index 22570535..3a8f729e 100644
--- a/src/clblast_netlib_c.cpp
+++ b/src/clblast_netlib_c.cpp
@@ -4878,7 +4878,8 @@ void cblas_zomatcopy(const CLBlastLayout layout, const CLBlastTranspose a_transp
}
// IM2COL
-void cblas_sim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_sim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const float* im,
float* col) {
OPTIONAL_STATIC auto device = get_device();
@@ -4891,7 +4892,8 @@ void cblas_sim2col(const int channels, const int height, const int width, const
im_buffer.Write(queue, im_size, reinterpret_cast<const float*>(im));
col_buffer.Write(queue, col_size, reinterpret_cast<float*>(col));
auto queue_cl = queue();
- auto s = clblast::Im2col<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Im2col<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer(), 0,
col_buffer(), 0,
&queue_cl);
@@ -4900,7 +4902,8 @@ void cblas_sim2col(const int channels, const int height, const int width, const
}
col_buffer.Read(queue, col_size, reinterpret_cast<float*>(col));
}
-void cblas_dim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_dim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const double* im,
double* col) {
OPTIONAL_STATIC auto device = get_device();
@@ -4913,7 +4916,8 @@ void cblas_dim2col(const int channels, const int height, const int width, const
im_buffer.Write(queue, im_size, reinterpret_cast<const double*>(im));
col_buffer.Write(queue, col_size, reinterpret_cast<double*>(col));
auto queue_cl = queue();
- auto s = clblast::Im2col<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Im2col<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer(), 0,
col_buffer(), 0,
&queue_cl);
@@ -4922,7 +4926,8 @@ void cblas_dim2col(const int channels, const int height, const int width, const
}
col_buffer.Read(queue, col_size, reinterpret_cast<double*>(col));
}
-void cblas_cim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_cim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* im,
void* col) {
OPTIONAL_STATIC auto device = get_device();
@@ -4935,7 +4940,8 @@ void cblas_cim2col(const int channels, const int height, const int width, const
im_buffer.Write(queue, im_size, reinterpret_cast<const float2*>(im));
col_buffer.Write(queue, col_size, reinterpret_cast<float2*>(col));
auto queue_cl = queue();
- auto s = clblast::Im2col<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Im2col<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer(), 0,
col_buffer(), 0,
&queue_cl);
@@ -4944,7 +4950,8 @@ void cblas_cim2col(const int channels, const int height, const int width, const
}
col_buffer.Read(queue, col_size, reinterpret_cast<float2*>(col));
}
-void cblas_zim2col(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_zim2col(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* im,
void* col) {
OPTIONAL_STATIC auto device = get_device();
@@ -4957,7 +4964,8 @@ void cblas_zim2col(const int channels, const int height, const int width, const
im_buffer.Write(queue, im_size, reinterpret_cast<const double2*>(im));
col_buffer.Write(queue, col_size, reinterpret_cast<double2*>(col));
auto queue_cl = queue();
- auto s = clblast::Im2col<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Im2col<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer(), 0,
col_buffer(), 0,
&queue_cl);
@@ -4968,7 +4976,8 @@ void cblas_zim2col(const int channels, const int height, const int width, const
}
// COL2IM
-void cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_scol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const float* col,
float* im) {
OPTIONAL_STATIC auto device = get_device();
@@ -4981,7 +4990,8 @@ void cblas_scol2im(const int channels, const int height, const int width, const
col_buffer.Write(queue, col_size, reinterpret_cast<const float*>(col));
im_buffer.Write(queue, im_size, reinterpret_cast<float*>(im));
auto queue_cl = queue();
- auto s = clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Col2im<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer(), 0,
im_buffer(), 0,
&queue_cl);
@@ -4990,7 +5000,8 @@ void cblas_scol2im(const int channels, const int height, const int width, const
}
im_buffer.Read(queue, im_size, reinterpret_cast<float*>(im));
}
-void cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_dcol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const double* col,
double* im) {
OPTIONAL_STATIC auto device = get_device();
@@ -5003,7 +5014,8 @@ void cblas_dcol2im(const int channels, const int height, const int width, const
col_buffer.Write(queue, col_size, reinterpret_cast<const double*>(col));
im_buffer.Write(queue, im_size, reinterpret_cast<double*>(im));
auto queue_cl = queue();
- auto s = clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Col2im<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer(), 0,
im_buffer(), 0,
&queue_cl);
@@ -5012,7 +5024,8 @@ void cblas_dcol2im(const int channels, const int height, const int width, const
}
im_buffer.Read(queue, im_size, reinterpret_cast<double*>(im));
}
-void cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_ccol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* col,
void* im) {
OPTIONAL_STATIC auto device = get_device();
@@ -5025,7 +5038,8 @@ void cblas_ccol2im(const int channels, const int height, const int width, const
col_buffer.Write(queue, col_size, reinterpret_cast<const float2*>(col));
im_buffer.Write(queue, im_size, reinterpret_cast<float2*>(im));
auto queue_cl = queue();
- auto s = clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Col2im<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer(), 0,
im_buffer(), 0,
&queue_cl);
@@ -5034,7 +5048,8 @@ void cblas_ccol2im(const int channels, const int height, const int width, const
}
im_buffer.Read(queue, im_size, reinterpret_cast<float2*>(im));
}
-void cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+void cblas_zcol2im(const CLBlastKernelMode kernel_mode,
+ const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
const void* col,
void* im) {
OPTIONAL_STATIC auto device = get_device();
@@ -5047,7 +5062,8 @@ void cblas_zcol2im(const int channels, const int height, const int width, const
col_buffer.Write(queue, col_size, reinterpret_cast<const double2*>(col));
im_buffer.Write(queue, im_size, reinterpret_cast<double2*>(im));
auto queue_cl = queue();
- auto s = clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ auto s = clblast::Col2im<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer(), 0,
im_buffer(), 0,
&queue_cl);
diff --git a/src/kernels/levelx/col2im.opencl b/src/kernels/levelx/col2im.opencl
index a37db24f..484a7a98 100644
--- a/src/kernels/levelx/col2im.opencl
+++ b/src/kernels/levelx/col2im.opencl
@@ -28,18 +28,20 @@ inline int grid_ceil(const int x, const int step) {
return x > 0 ? ((x - 1) / step + 1) * step : x / step * step;
}
+// Main body of the kernel
__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
-void col2im(const int input_h, const int input_w, const int channels,
- const int output_h, const int output_w,
- const int kernel_h, const int kernel_w,
- const int pad_h, const int pad_w,
- const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w,
- const int stride_bez_h, const int stride_bez_w,
- const int dilation_bez_h, const int dilation_bez_w,
- const int gcd_h, const int gcd_w,
- const __global real* restrict col_buffer, const int col_offset,
- __global real* im_buffer, const int im_offset) {
+INLINE_FUNC void Xcol2im(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int stride_bez_h, const int stride_bez_w,
+ const int dilation_bez_h, const int dilation_bez_w,
+ const int gcd_h, const int gcd_w,
+ const bool kernel_flip,
+ const __global real* restrict col_buffer, const int col_offset,
+ __global real* im_buffer, const int im_offset) {
const int input_h_scaled = (input_h - 1) / gcd_h + 1;
@@ -71,8 +73,9 @@ void col2im(const int input_h, const int input_w, const int channels,
const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w;
const int h_id = th / stride_h + stride_bez_h * gcd_scale_h;
const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w;
-
- const int kernel_index = kw_id + kernel_w * kh_id;
+ const int kernel_index = (kernel_flip)
+ ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1
+ : kw_id + kernel_w * kh_id;
const int patch_index = w_id + output_w * h_id;
const int output_index = patch_index + kernel_index * output_w * output_h +
c_id * output_w * output_h * kernel_h * kernel_w;
@@ -89,6 +92,50 @@ void col2im(const int input_h, const int input_w, const int channels,
// =================================================================================================
+// Kernel flip version of the Xcol2im kernel (for convolution)
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xcol2imKernelFlip(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int stride_bez_h, const int stride_bez_w,
+ const int dilation_bez_h, const int dilation_bez_w,
+ const int gcd_h, const int gcd_w,
+ const __global real* restrict col_buffer, const int col_offset,
+ __global real* im_buffer, const int im_offset) {
+ const bool kernel_flip = true;
+ Xcol2im(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ stride_bez_h, stride_bez_w, dilation_bez_h, dilation_bez_w, gcd_h, gcd_w,
+ kernel_flip,
+ col_buffer, col_offset, im_buffer, im_offset);
+}
+
+// Normal version of the Xcol2im kernel (for cross-correlation)
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xcol2imKernelNormal(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int stride_bez_h, const int stride_bez_w,
+ const int dilation_bez_h, const int dilation_bez_w,
+ const int gcd_h, const int gcd_w,
+ const __global real* restrict col_buffer, const int col_offset,
+ __global real* im_buffer, const int im_offset) {
+ const bool kernel_flip = false;
+ Xcol2im(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ stride_bez_h, stride_bez_w, dilation_bez_h, dilation_bez_w, gcd_h, gcd_w,
+ kernel_flip,
+ col_buffer, col_offset, im_buffer, im_offset);
+}
+
+// =================================================================================================
+
// End of the C++11 raw string literal
)"
diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl
index 301e076b..5db4cb5f 100644
--- a/src/kernels/levelx/im2col.opencl
+++ b/src/kernels/levelx/im2col.opencl
@@ -25,15 +25,16 @@ R"(
// =================================================================================================
-__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
-void im2col(const int input_h, const int input_w, const int channels,
- const int output_h, const int output_w,
- const int kernel_h, const int kernel_w,
- const int pad_h, const int pad_w,
- const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w,
- const __global real* restrict im_buffer, const int im_offset,
- __global real* col_buffer, const int col_offset) {
+// Main body of the kernel
+INLINE_FUNC void Xim2col(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const bool kernel_flip,
+ const __global real* restrict im_buffer, const int im_offset,
+ __global real* col_buffer, const int col_offset) {
// Thread IDs
const int w_id = get_global_id(0); // image width, max 'output_w'
@@ -58,7 +59,9 @@ void im2col(const int input_h, const int input_w, const int channels,
}
// Sets the output value
- const int kernel_index = kw_id + kernel_w * kh_id;
+ const int kernel_index = (kernel_flip)
+ ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1
+ : kw_id + kernel_w * kh_id;
const int patch_index = w_id + output_w * h_id;
const int output_index = patch_index + kernel_index * output_w * output_h +
c_id * output_w * output_h * kernel_h * kernel_w;
@@ -70,6 +73,42 @@ void im2col(const int input_h, const int input_w, const int channels,
// =================================================================================================
+// Kernel flip version of the Xim2col kernel (for convolution)
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xim2colKernelFlip(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const __global real* restrict im_buffer, const int im_offset,
+ __global real* col_buffer, const int col_offset) {
+ const bool kernel_flip = true;
+ Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ kernel_flip,
+ im_buffer, im_offset, col_buffer, col_offset);
+}
+
+// Normal version of the Xim2col kernel (for cross-correlation)
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xim2colKernelNormal(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const __global real* restrict im_buffer, const int im_offset,
+ __global real* col_buffer, const int col_offset) {
+ const bool kernel_flip = false;
+ Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ kernel_flip,
+ im_buffer, im_offset, col_buffer, col_offset);
+}
+
+// =================================================================================================
+
// End of the C++11 raw string literal
)"
diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp
index 7a0c36b7..d285e5c0 100644
--- a/src/routines/levelx/xcol2im.cpp
+++ b/src/routines/levelx/xcol2im.cpp
@@ -31,13 +31,17 @@ Xcol2im<T>::Xcol2im(Queue &queue, EventPointer event, const std::string &name):
// The main routine
template <typename T>
-void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size_t width,
+void Xcol2im<T>::DoCol2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
const size_t pad_w, const size_t stride_h, const size_t stride_w,
const size_t dilation_h, const size_t dilation_w,
const Buffer<T> &col_buffer, const size_t col_offset,
const Buffer<T> &im_buffer, const size_t im_offset) {
+ // Flip the output along kernel_h and kernel_w, or not.
+ const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xcol2imKernelFlip" : "Xcol2imKernelNormal";
+
// Makes sure all dimensions are larger than zero
if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
@@ -59,7 +63,7 @@ void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size
EuclidGCD(static_cast<int>(stride_w), static_cast<int>(dilation_w), stride_bez_w, dilation_bez_w, gcd_w);
// Retrieves the kernel from the compiled binary
- auto kernel = Kernel(program_, "col2im");
+ auto kernel = Kernel(program_, kernel_name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(height));
diff --git a/src/routines/levelx/xcol2im.hpp b/src/routines/levelx/xcol2im.hpp
index 86d68c45..522c717e 100644
--- a/src/routines/levelx/xcol2im.hpp
+++ b/src/routines/levelx/xcol2im.hpp
@@ -29,7 +29,8 @@ class Xcol2im: public Routine {
Xcol2im(Queue &queue, EventPointer event, const std::string &name = "COL2IM");
// Templated-precision implementation of the routine
- void DoCol2im(const size_t channels, const size_t height, const size_t width,
+ void DoCol2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w,
const size_t pad_h, const size_t pad_w,
const size_t stride_h, const size_t stride_w,
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index f26f23a7..88127b0f 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -43,7 +43,8 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam
// =================================================================================================
template <typename T>
-void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const size_t width,
+void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
const size_t pad_w, const size_t stride_h, const size_t stride_w,
const size_t dilation_h, const size_t dilation_w,
@@ -94,7 +95,8 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
const auto col_batch_offset = batch_id * patch_size * num_patches;
auto im2col_event = Event();
auto im2col = Xim2col<T>(queue_, im2col_event.pointer());
- im2col.DoIm2col(channels, height, width, kernel_h, kernel_w,
+ im2col.DoIm2col(kernel_mode,
+ channels, height, width, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_batch_offset,
col_buffer, col_batch_offset);
diff --git a/src/routines/levelx/xconvgemm.hpp b/src/routines/levelx/xconvgemm.hpp
index 9d11ccee..20cfff60 100644
--- a/src/routines/levelx/xconvgemm.hpp
+++ b/src/routines/levelx/xconvgemm.hpp
@@ -32,7 +32,8 @@ class Xconvgemm: public Routine {
const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col);
// Templated-precision implementation of the routine
- void DoConvgemm(const size_t channels, const size_t height, const size_t width,
+ void DoConvgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w,
const size_t pad_h, const size_t pad_w,
const size_t stride_h, const size_t stride_w,
diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp
index 09dcc42c..0f786974 100644
--- a/src/routines/levelx/xim2col.cpp
+++ b/src/routines/levelx/xim2col.cpp
@@ -22,22 +22,26 @@ namespace clblast {
// Constructor: forwards to base class constructor
template <typename T>
Xim2col<T>::Xim2col(Queue &queue, EventPointer event, const std::string &name):
- Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, {
-#include "../../kernels/levelx/im2col.opencl"
- }) {
+ Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, {
+ #include "../../kernels/levelx/im2col.opencl"
+ }) {
}
// =================================================================================================
// The main routine
template <typename T>
-void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size_t width,
+void Xim2col<T>::DoIm2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
const size_t pad_w, const size_t stride_h, const size_t stride_w,
const size_t dilation_h, const size_t dilation_w,
const Buffer<T> &im_buffer, const size_t im_offset,
const Buffer<T> &col_buffer, const size_t col_offset) {
+ // Flip the output along kernel_h and kernel_w, or not.
+ const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xim2colKernelFlip" : "Xim2colKernelNormal";
+
// Makes sure all dimensions are larger than zero
if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
@@ -50,7 +54,7 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size
const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
// Retrieves the kernel from the compiled binary
- auto kernel = Kernel(program_, "im2col");
+ auto kernel = Kernel(program_, kernel_name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(height));
diff --git a/src/routines/levelx/xim2col.hpp b/src/routines/levelx/xim2col.hpp
index 2c03b169..77cc32eb 100644
--- a/src/routines/levelx/xim2col.hpp
+++ b/src/routines/levelx/xim2col.hpp
@@ -29,7 +29,8 @@ class Xim2col: public Routine {
Xim2col(Queue &queue, EventPointer event, const std::string &name = "IM2COL");
// Templated-precision implementation of the routine
- void DoIm2col(const size_t channels, const size_t height, const size_t width,
+ void DoIm2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w,
const size_t pad_h, const size_t pad_w,
const size_t stride_h, const size_t stride_w,
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp
index a6cd82e7..a0e89c98 100644
--- a/src/utilities/utilities.cpp
+++ b/src/utilities/utilities.cpp
@@ -175,6 +175,13 @@ std::string ToString(Precision value) {
}
}
template <>
+std::string ToString(KernelMode value) { // Formats a kernel mode as "<integer> (<name>)"
+  const auto id = ToString(static_cast<int>(value)); // integer value of the enum as text
+  if (value == KernelMode::kCrossCorrelation) { return id + " (cross-correlation)"; }
+  if (value == KernelMode::kConvolution) { return id + " (convolution)"; }
+  return id; // unknown value (e.g. cast from user input): never flow off the end of the function
+}
+template <>
std::string ToString(StatusCode value) {
return std::to_string(static_cast<int>(value));
}
@@ -281,6 +288,7 @@ template Side GetArgument<Side>(const std::vector<std::string>&, std::string&, c
template Triangle GetArgument<Triangle>(const std::vector<std::string>&, std::string&, const std::string&, const Triangle);
template Diagonal GetArgument<Diagonal>(const std::vector<std::string>&, std::string&, const std::string&, const Diagonal);
template Precision GetArgument<Precision>(const std::vector<std::string>&, std::string&, const std::string&, const Precision);
+template KernelMode GetArgument<KernelMode>(const std::vector<std::string>&, std::string&, const std::string&, const KernelMode);
// =================================================================================================
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index fcc1c57f..23486d35 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -69,6 +69,7 @@ constexpr auto kArgBTransp = "transB";
constexpr auto kArgSide = "side";
constexpr auto kArgTriangle = "triangle";
constexpr auto kArgDiagonal = "diagonal";
+constexpr auto kArgKernelMode = "kernel_mode";
constexpr auto kArgXInc = "incx";
constexpr auto kArgYInc = "incy";
constexpr auto kArgXOffset = "offx";
@@ -183,6 +184,7 @@ struct Arguments {
Side side = Side::kLeft;
Triangle triangle = Triangle::kUpper;
Diagonal diagonal = Diagonal::kUnit;
+ KernelMode kernel_mode = KernelMode::kCrossCorrelation;
size_t x_inc = 1;
size_t y_inc = 1;
size_t x_offset = 0;