From 9fb2c61b256ccf66b6a7b6f605008125288d60cf Mon Sep 17 00:00:00 2001
From: Cedric Nugteren
Date: Sun, 7 Jan 2018 14:27:15 +0100
Subject: Added API and tests for new GemmStridedBatched routine

---
Reviewer note (not part of the commit): the hunk below adds the templated C++
API entry point GemmStridedBatched<T>. It wraps the caller's command queue in a
Queue, constructs an XgemmStridedBatched<T> routine, and forwards all arguments
to DoGemmStridedBatched with the three cl_mem handles wrapped in Buffer<T>. It
returns StatusCode::kSuccess on completion; any exception is caught and
converted to a status code through DispatchException(). The template is then
explicitly instantiated for float, double, float2, double2 and half.

 src/clblast.cpp | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)

(limited to 'src/clblast.cpp')

diff --git a/src/clblast.cpp b/src/clblast.cpp
index f5e2f1be..c4c51538 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2336,6 +2336,77 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
                                                  const size_t,
                                                  cl_command_queue*, cl_event*);
 
+// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
+template <typename T>
+StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+                              const size_t m, const size_t n, const size_t k,
+                              const T alpha,
+                              const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
+                              const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
+                              const T beta,
+                              cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
+                              const size_t batch_count,
+                              cl_command_queue* queue, cl_event* event) {
+  try {
+    auto queue_cpp = Queue(*queue);
+    auto routine = XgemmStridedBatched<T>(queue_cpp, event);
+    routine.DoGemmStridedBatched(layout, a_transpose, b_transpose,
+                                 m, n, k,
+                                 alpha,
+                                 Buffer<T>(a_buffer), a_offset, a_ld, a_stride,
+                                 Buffer<T>(b_buffer), b_offset, b_ld, b_stride,
+                                 beta,
+                                 Buffer<T>(c_buffer), c_offset, c_ld, c_stride,
+                                 batch_count);
+    return StatusCode::kSuccess;
+  } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API GemmStridedBatched<float>(const Layout, const Transpose, const Transpose,
+                                                         const size_t, const size_t, const size_t,
+                                                         const float,
+                                                         const cl_mem, const size_t, const size_t, const size_t,
+                                                         const cl_mem, const size_t, const size_t, const size_t,
+                                                         const float,
+                                                         cl_mem, const size_t, const size_t, const size_t,
+                                                         const size_t,
+                                                         cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API GemmStridedBatched<double>(const Layout, const Transpose, const Transpose,
+                                                          const size_t, const size_t, const size_t,
+                                                          const double,
+                                                          const cl_mem, const size_t, const size_t, const size_t,
+                                                          const cl_mem, const size_t, const size_t, const size_t,
+                                                          const double,
+                                                          cl_mem, const size_t, const size_t, const size_t,
+                                                          const size_t,
+                                                          cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API GemmStridedBatched<float2>(const Layout, const Transpose, const Transpose,
+                                                          const size_t, const size_t, const size_t,
+                                                          const float2,
+                                                          const cl_mem, const size_t, const size_t, const size_t,
+                                                          const cl_mem, const size_t, const size_t, const size_t,
+                                                          const float2,
+                                                          cl_mem, const size_t, const size_t, const size_t,
+                                                          const size_t,
+                                                          cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API GemmStridedBatched<double2>(const Layout, const Transpose, const Transpose,
+                                                           const size_t, const size_t, const size_t,
+                                                           const double2,
+                                                           const cl_mem, const size_t, const size_t, const size_t,
+                                                           const cl_mem, const size_t, const size_t, const size_t,
+                                                           const double2,
+                                                           cl_mem, const size_t, const size_t, const size_t,
+                                                           const size_t,
+                                                           cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API GemmStridedBatched<half>(const Layout, const Transpose, const Transpose,
+                                                        const size_t, const size_t, const size_t,
+                                                        const half,
+                                                        const cl_mem, const size_t, const size_t, const size_t,
+                                                        const cl_mem, const size_t, const size_t, const size_t,
+                                                        const half,
+                                                        cl_mem, const size_t, const size_t, const size_t,
+                                                        const size_t,
+                                                        cl_command_queue*, cl_event*);
+
 // =================================================================================================
 
 // Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
-- 
cgit v1.2.3