From 2d1f6ba7fe842ba938490fc599b6ebd209b6560b Mon Sep 17 00:00:00 2001
From: Cedric Nugteren
Date: Sun, 6 May 2018 11:35:34 +0200
Subject: Added convgemm skeleton, test infrastructure, and first reference implementation

---
 src/clblast_cuda.cpp | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

(limited to 'src/clblast_cuda.cpp')

diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index f89fb77d..5aab1626 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2352,12 +2352,22 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
 
 // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
 template <typename T>
-StatusCode Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
-                    const CUdeviceptr, const size_t,
-                    const CUdeviceptr, const size_t,
-                    CUdeviceptr, const size_t,
-                    const CUcontext, const CUdevice) {
-  return StatusCode::kNotImplemented;
+StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+                    const CUdeviceptr im_buffer, const size_t im_offset,
+                    const CUdeviceptr kernel_buffer, const size_t kernel_offset,
+                    CUdeviceptr result_buffer, const size_t result_offset,
+                    const CUcontext context, const CUdevice device) {
+  try {
+    const auto context_cpp = Context(context);
+    const auto device_cpp = Device(device);
+    auto queue_cpp = Queue(context_cpp, device_cpp);
+    auto routine = Xconvgemm<T>(queue_cpp, nullptr);
+    routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+                       Buffer<T>(im_buffer), im_offset,
+                       Buffer<T>(kernel_buffer), kernel_offset,
+                       Buffer<T>(result_buffer), result_offset);
+    return StatusCode::kSuccess;
+  } catch (...) { return DispatchException(); }
 }
 template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
                                                const CUdeviceptr, const size_t,
-- 
cgit v1.2.3