diff options
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- | src/clblast.cpp | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index 026285bb..3a96136a 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2254,12 +2254,20 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM template <typename T> -StatusCode Convgemm(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, - const cl_mem, const size_t, - const cl_mem, const size_t, - cl_mem, const size_t, - cl_command_queue*, cl_event*) { - return StatusCode::kNotImplemented; +StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, + const cl_mem im_buffer, const size_t im_offset, + const cl_mem kernel_buffer, const size_t kernel_offset, + cl_mem result_buffer, const size_t result_offset, + cl_command_queue* queue, cl_event* event) { + try { + auto queue_cpp = Queue(*queue); + auto routine = Xconvgemm<T>(queue_cpp, event); + routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + Buffer<T>(im_buffer), im_offset, + Buffer<T>(kernel_buffer), kernel_offset, + Buffer<T>(result_buffer), result_offset); + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, |