summaryrefslogtreecommitdiff
path: root/src/clblast_cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/clblast_cuda.cpp')
-rw-r--r--src/clblast_cuda.cpp35
1 files changed, 35 insertions, 0 deletions
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 8927014b..f14806cb 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2350,6 +2350,41 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
+// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
+template <typename T>
+StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+ const CUdeviceptr im_buffer, const size_t im_offset,
+ const CUdeviceptr kernel_buffer, const size_t kernel_offset,
+ CUdeviceptr result_buffer, const size_t result_offset,
+ const CUcontext context, const CUdevice device) {
+ try {
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
+ auto routine = Xconvgemm<T>(queue_cpp, nullptr);
+ routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ Buffer<T>(im_buffer), im_offset,
+ Buffer<T>(kernel_buffer), kernel_offset,
+ Buffer<T>(result_buffer), result_offset);
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t,
+ const CUdeviceptr, const size_t,
+ CUdeviceptr, const size_t,
+ const CUcontext, const CUdevice);
+template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t,
+ const CUdeviceptr, const size_t,
+ CUdeviceptr, const size_t,
+ const CUcontext, const CUdevice);
+template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t,
+ const CUdeviceptr, const size_t,
+ CUdeviceptr, const size_t,
+ const CUcontext, const CUdevice);
+
// Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED
template <typename T>
StatusCode AxpyBatched(const size_t n,