From 49e04c7fce8fed45559e143137cef3a1a36328cc Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Fri, 10 Mar 2017 21:24:35 +0100 Subject: Added API and test infrastructure for the batched GEMM routine --- src/clblast.cpp | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) (limited to 'src/clblast.cpp') diff --git a/src/clblast.cpp b/src/clblast.cpp index d3db8edf..a8bcf91d 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -72,6 +72,7 @@ // Level-x includes (non-BLAS) #include "routines/levelx/xomatcopy.hpp" #include "routines/levelx/xaxpybatched.hpp" +#include "routines/levelx/xgemmbatched.hpp" namespace clblast { @@ -2231,6 +2232,89 @@ template StatusCode PUBLIC_API AxpyBatched(const size_t, cl_mem, const size_t*, const size_t, const size_t, cl_command_queue*, cl_event*); + +// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED +template +StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const T *alphas, + const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, + const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, + const T *betas, + cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + try { + auto queue_cpp = Queue(*queue); + auto routine = XgemmBatched(queue_cpp, event); + auto alphas_cpp = std::vector(); + auto betas_cpp = std::vector(); + auto a_offsets_cpp = std::vector(); + auto b_offsets_cpp = std::vector(); + auto c_offsets_cpp = std::vector(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(alphas[batch]); + betas_cpp.push_back(betas[batch]); + a_offsets_cpp.push_back(a_offsets[batch]); + b_offsets_cpp.push_back(b_offsets[batch]); + c_offsets_cpp.push_back(c_offsets[batch]); + } + routine.DoGemmBatched(layout, a_transpose, b_transpose, + m, n, k, + alphas_cpp, + Buffer(a_buffer), a_offsets_cpp, a_ld, + Buffer(b_buffer), b_offsets_cpp, b_ld, + betas_cpp, + Buffer(c_buffer), c_offsets_cpp, c_ld, + batch_count); + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const float*, + const cl_mem, const size_t*, const size_t, + const cl_mem, const size_t*, const size_t, + const float*, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const double*, + const cl_mem, const size_t*, const size_t, + const cl_mem, const size_t*, const size_t, + const double*, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const float2*, + const cl_mem, const size_t*, const size_t, + const cl_mem, const size_t*, const size_t, + const float2*, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const double2*, + const cl_mem, const size_t*, const size_t, + const cl_mem, const size_t*, const size_t, + const double2*, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const half*, + const cl_mem, const size_t*, const size_t, + const cl_mem, const size_t*, const size_t, + const half*, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); // ================================================================================================= // Clears the cache of stored binaries -- cgit v1.2.3