summaryrefslogtreecommitdiff
path: root/src/clblast_cuda.cpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-05-06 11:35:34 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2018-05-06 11:35:34 +0200
commit2d1f6ba7fe842ba938490fc599b6ebd209b6560b (patch)
treef1a284e5dc0163b7fed938a3efeb39432b9d3788 /src/clblast_cuda.cpp
parent2776d761768295b01a8be7c333dbb337805d7f77 (diff)
Added convgemm skeleton, test infrastructure, and first reference implementation
Diffstat (limited to 'src/clblast_cuda.cpp')
-rw-r--r--src/clblast_cuda.cpp22
1 files changed, 16 insertions, 6 deletions
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index f89fb77d..5aab1626 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2352,12 +2352,22 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
template <typename T>
-StatusCode Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
- const CUdeviceptr, const size_t,
- const CUdeviceptr, const size_t,
- CUdeviceptr, const size_t,
- const CUcontext, const CUdevice) {
- return StatusCode::kNotImplemented;
+StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+ const CUdeviceptr im_buffer, const size_t im_offset,
+ const CUdeviceptr kernel_buffer, const size_t kernel_offset,
+ CUdeviceptr result_buffer, const size_t result_offset,
+ const CUcontext context, const CUdevice device) {
+ try {
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
+ auto routine = Xconvgemm<T>(queue_cpp, nullptr);
+ routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ Buffer<T>(im_buffer), im_offset,
+ Buffer<T>(kernel_buffer), kernel_offset,
+ Buffer<T>(result_buffer), result_offset);
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,