summaryrefslogtreecommitdiff
path: root/src/clblast.cpp
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2017-03-10 21:24:35 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2017-03-10 21:24:35 +0100
commit 49e04c7fce8fed45559e143137cef3a1a36328cc (patch)
tree f73a5c280f12cc5e38f6d4fd4e853b8b8e1aa432 /src/clblast.cpp
parent de3500ed18ddb39261ffa270f460909571276462 (diff)
Added API and test infrastructure for the batched GEMM routine
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- src/clblast.cpp 84
1 files changed, 84 insertions, 0 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index d3db8edf..a8bcf91d 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -72,6 +72,7 @@
// Level-x includes (non-BLAS)
#include "routines/levelx/xomatcopy.hpp"
#include "routines/levelx/xaxpybatched.hpp"
+#include "routines/levelx/xgemmbatched.hpp"
namespace clblast {
@@ -2231,6 +2232,89 @@ template StatusCode PUBLIC_API AxpyBatched<half>(const size_t,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
+
+// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
+//
+// Public entry point of the batched GEMM routine. Copies the per-batch host arrays (alphas,
+// betas and the a/b/c offsets) into std::vectors of 'batch_count' elements, wraps the three
+// cl_mem handles in Buffer<T> objects and forwards everything to XgemmBatched<T>::DoGemmBatched.
+//
+// Requirements visible from this function:
+//  - 'alphas', 'betas', 'a_offsets', 'b_offsets' and 'c_offsets' are each read over the index
+//    range [0, batch_count) and therefore must hold at least 'batch_count' elements.
+//  - 'queue' is dereferenced unconditionally and must not be null.
+//
+// Returns StatusCode::kSuccess if nothing throws; otherwise any exception is caught here and the
+// result of DispatchException() is returned instead.
+template <typename T>
+StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+                       const size_t m, const size_t n, const size_t k,
+                       const T *alphas,
+                       const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+                       const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+                       const T *betas,
+                       cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+                       const size_t batch_count,
+                       cl_command_queue* queue, cl_event* event) {
+  try {
+    auto queue_cpp = Queue(*queue);
+    auto routine = XgemmBatched<T>(queue_cpp, event);
+
+    // Build every per-batch vector directly from its [ptr, ptr + batch_count) range: a single
+    // exact-size allocation per vector instead of repeated regrowth through push_back. Note the
+    // parentheses: braces would select the initializer_list constructor.
+    auto alphas_cpp = std::vector<T>(alphas, alphas + batch_count);
+    auto betas_cpp = std::vector<T>(betas, betas + batch_count);
+    auto a_offsets_cpp = std::vector<size_t>(a_offsets, a_offsets + batch_count);
+    auto b_offsets_cpp = std::vector<size_t>(b_offsets, b_offsets + batch_count);
+    auto c_offsets_cpp = std::vector<size_t>(c_offsets, c_offsets + batch_count);
+
+    routine.DoGemmBatched(layout, a_transpose, b_transpose,
+                          m, n, k,
+                          alphas_cpp,
+                          Buffer<T>(a_buffer), a_offsets_cpp, a_ld,
+                          Buffer<T>(b_buffer), b_offsets_cpp, b_ld,
+                          betas_cpp,
+                          Buffer<T>(c_buffer), c_offsets_cpp, c_ld,
+                          batch_count);
+    return StatusCode::kSuccess;
+  } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API GemmBatched<float>(const Layout, const Transpose, const Transpose, // explicit instantiation, T = float
+ const size_t, const size_t, const size_t, // m, n, k
+ const float*, // alphas
+ const cl_mem, const size_t*, const size_t, // a_buffer, a_offsets, a_ld
+ const cl_mem, const size_t*, const size_t, // b_buffer, b_offsets, b_ld
+ const float*, // betas
+ cl_mem, const size_t*, const size_t, // c_buffer, c_offsets, c_ld
+ const size_t, // batch_count
+ cl_command_queue*, cl_event*); // queue, event
+template StatusCode PUBLIC_API GemmBatched<double>(const Layout, const Transpose, const Transpose, // explicit instantiation, T = double
+ const size_t, const size_t, const size_t, // m, n, k
+ const double*, // alphas
+ const cl_mem, const size_t*, const size_t, // a_buffer, a_offsets, a_ld
+ const cl_mem, const size_t*, const size_t, // b_buffer, b_offsets, b_ld
+ const double*, // betas
+ cl_mem, const size_t*, const size_t, // c_buffer, c_offsets, c_ld
+ const size_t, // batch_count
+ cl_command_queue*, cl_event*); // queue, event
+template StatusCode PUBLIC_API GemmBatched<float2>(const Layout, const Transpose, const Transpose, // explicit instantiation, T = float2 (project-declared type; presumably single-precision complex, confirm)
+ const size_t, const size_t, const size_t, // m, n, k
+ const float2*, // alphas
+ const cl_mem, const size_t*, const size_t, // a_buffer, a_offsets, a_ld
+ const cl_mem, const size_t*, const size_t, // b_buffer, b_offsets, b_ld
+ const float2*, // betas
+ cl_mem, const size_t*, const size_t, // c_buffer, c_offsets, c_ld
+ const size_t, // batch_count
+ cl_command_queue*, cl_event*); // queue, event
+template StatusCode PUBLIC_API GemmBatched<double2>(const Layout, const Transpose, const Transpose, // explicit instantiation, T = double2 (project-declared type; presumably double-precision complex, confirm)
+ const size_t, const size_t, const size_t, // m, n, k
+ const double2*, // alphas
+ const cl_mem, const size_t*, const size_t, // a_buffer, a_offsets, a_ld
+ const cl_mem, const size_t*, const size_t, // b_buffer, b_offsets, b_ld
+ const double2*, // betas
+ cl_mem, const size_t*, const size_t, // c_buffer, c_offsets, c_ld
+ const size_t, // batch_count
+ cl_command_queue*, cl_event*); // queue, event
+template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, const Transpose, // explicit instantiation, T = half (project-declared type; presumably half precision, confirm)
+ const size_t, const size_t, const size_t, // m, n, k
+ const half*, // alphas
+ const cl_mem, const size_t*, const size_t, // a_buffer, a_offsets, a_ld
+ const cl_mem, const size_t*, const size_t, // b_buffer, b_offsets, b_ld
+ const half*, // betas
+ cl_mem, const size_t*, const size_t, // c_buffer, c_offsets, c_ld
+ const size_t, // batch_count
+ cl_command_queue*, cl_event*); // queue, event
// =================================================================================================
// Clears the cache of stored binaries