summaryrefslogtreecommitdiff
path: root/src/clblast.cpp
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2017-03-10 21:15:29 +0100
committer GitHub <noreply@github.com> 2017-03-10 21:15:29 +0100
commit de3500ed18ddb39261ffa270f460909571276462 (patch)
tree b515368fcd1e39afb5805f67796b082ccc8066f9 /src/clblast.cpp
parent 37228c90988509acef9e8a892a752300b7645210 (diff)
parent 3846f44eaf389ee24a698d4947e5c16bd14c3d0e (diff)
Merge pull request #141 from CNugteren/axpy_batched
Added the batched version of the AXPY routine
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- src/clblast.cpp 59
1 files changed, 59 insertions, 0 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index a63d766c..d3db8edf 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -71,6 +71,7 @@
// Level-x includes (non-BLAS)
#include "routines/levelx/xomatcopy.hpp"
+#include "routines/levelx/xaxpybatched.hpp"
namespace clblast {
@@ -2172,6 +2173,64 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
const cl_mem, const size_t, const size_t,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
+
+// Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED
+//
+// C-style entry point that wraps the raw OpenCL handles and forwards to
+// XaxpyBatched<T>::DoAxpyBatched. The arrays 'alphas', 'x_offsets' and 'y_offsets' are each read
+// at indices [0, batch_count) and copied into std::vector objects before the call; 'n', 'x_inc',
+// 'y_inc' and 'batch_count' are forwarded unchanged. 'queue' is dereferenced unconditionally.
+//
+// Returns StatusCode::kSuccess when the routine finishes without throwing; any exception raised
+// along the way is turned into the return value by DispatchException().
+template <typename T>
+StatusCode AxpyBatched(const size_t n,
+                       const T *alphas,
+                       const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
+                       cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
+                       const size_t batch_count,
+                       cl_command_queue* queue, cl_event* event) {
+  try {
+    auto queue_cpp = Queue(*queue);
+    auto routine = XaxpyBatched<T>(queue_cpp, event);
+
+    // Copies the first 'batch_count' entries of each input array. The iterator-range constructor
+    // sizes every vector once up front instead of growing it element by element.
+    auto alphas_cpp = std::vector<T>(alphas, alphas + batch_count);
+    auto x_offsets_cpp = std::vector<size_t>(x_offsets, x_offsets + batch_count);
+    auto y_offsets_cpp = std::vector<size_t>(y_offsets, y_offsets + batch_count);
+
+    routine.DoAxpyBatched(n,
+                          alphas_cpp,
+                          Buffer<T>(x_buffer), x_offsets_cpp, x_inc,
+                          Buffer<T>(y_buffer), y_offsets_cpp, y_inc,
+                          batch_count);
+    return StatusCode::kSuccess;
+  } catch (...) { return DispatchException(); }
+}
+// Explicit instantiations of AxpyBatched, one per supported element type:
+// float, double, float2, double2 and half.
+
+// T = float
+template StatusCode PUBLIC_API AxpyBatched<float>(
+    const size_t, const float*,
+    const cl_mem, const size_t*, const size_t,
+    cl_mem, const size_t*, const size_t,
+    const size_t, cl_command_queue*, cl_event*);
+
+// T = double
+template StatusCode PUBLIC_API AxpyBatched<double>(
+    const size_t, const double*,
+    const cl_mem, const size_t*, const size_t,
+    cl_mem, const size_t*, const size_t,
+    const size_t, cl_command_queue*, cl_event*);
+
+// T = float2
+template StatusCode PUBLIC_API AxpyBatched<float2>(
+    const size_t, const float2*,
+    const cl_mem, const size_t*, const size_t,
+    cl_mem, const size_t*, const size_t,
+    const size_t, cl_command_queue*, cl_event*);
+
+// T = double2
+template StatusCode PUBLIC_API AxpyBatched<double2>(
+    const size_t, const double2*,
+    const cl_mem, const size_t*, const size_t,
+    cl_mem, const size_t*, const size_t,
+    const size_t, cl_command_queue*, cl_event*);
+
+// T = half
+template StatusCode PUBLIC_API AxpyBatched<half>(
+    const size_t, const half*,
+    const cl_mem, const size_t*, const size_t,
+    cl_mem, const size_t*, const size_t,
+    const size_t, cl_command_queue*, cl_event*);
// =================================================================================================
// Clears the cache of stored binaries