summaryrefslogtreecommitdiff
path: root/src/clblast.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r--src/clblast.cpp57
1 files changed, 38 insertions, 19 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index e45f504a..180693e7 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2218,79 +2218,94 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T>
-StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Im2col(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = Xim2col<T>(queue_cpp, event);
- routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoIm2col(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(col_buffer), col_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Im2col<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
template <typename T>
-StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+StatusCode Col2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = Xcol2im<T>(queue_cpp, event);
- routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ routine.DoCol2im(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(col_buffer), col_offset,
Buffer<T>(im_buffer), im_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<float2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<double2>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Col2im<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
-StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+StatusCode Convgemm(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
@@ -2298,24 +2313,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid
try {
auto queue_cpp = Queue(*queue);
auto routine = Xconvgemm<T>(queue_cpp, event);
- routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ routine.DoConvgemm(kernel_mode,
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
Buffer<T>(im_buffer), im_offset,
Buffer<T>(kernel_buffer), kernel_offset,
Buffer<T>(result_buffer), result_offset);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
-template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<float>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<double>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
-template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+template StatusCode PUBLIC_API Convgemm<half>(const KernelMode,
+ const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,