diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-10 21:15:29 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-10 21:15:29 +0100 |
commit | de3500ed18ddb39261ffa270f460909571276462 (patch) | |
tree | b515368fcd1e39afb5805f67796b082ccc8066f9 /src/clblast.cpp | |
parent | 37228c90988509acef9e8a892a752300b7645210 (diff) | |
parent | 3846f44eaf389ee24a698d4947e5c16bd14c3d0e (diff) |
Merge pull request #141 from CNugteren/axpy_batched
Added the batched version of the AXPY routine
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- | src/clblast.cpp | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index a63d766c..d3db8edf 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -71,6 +71,7 @@ // Level-x includes (non-BLAS) #include "routines/levelx/xomatcopy.hpp" +#include "routines/levelx/xaxpybatched.hpp" namespace clblast { @@ -2172,6 +2173,64 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose, const cl_mem, const size_t, const size_t, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); + +// Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED +template <typename T> +StatusCode AxpyBatched(const size_t n, + const T *alphas, + const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, + cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + try { + auto queue_cpp = Queue(*queue); + auto routine = XaxpyBatched<T>(queue_cpp, event); + auto alphas_cpp = std::vector<T>(); + auto x_offsets_cpp = std::vector<size_t>(); + auto y_offsets_cpp = std::vector<size_t>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(alphas[batch]); + x_offsets_cpp.push_back(x_offsets[batch]); + y_offsets_cpp.push_back(y_offsets[batch]); + } + routine.DoAxpyBatched(n, + alphas_cpp, + Buffer<T>(x_buffer), x_offsets_cpp, x_inc, + Buffer<T>(y_buffer), y_offsets_cpp, y_inc, + batch_count); + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API AxpyBatched<float>(const size_t, + const float*, + const cl_mem, const size_t*, const size_t, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API AxpyBatched<double>(const size_t, + const double*, + const cl_mem, const size_t*, const size_t, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API AxpyBatched<float2>(const size_t, + const float2*, + const cl_mem, const size_t*, const size_t, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API AxpyBatched<double2>(const size_t, + const double2*, + const cl_mem, const size_t*, const size_t, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API AxpyBatched<half>(const size_t, + const half*, + const cl_mem, const size_t*, const size_t, + cl_mem, const size_t*, const size_t, + const size_t, + cl_command_queue*, cl_event*); // ================================================================================================= // Clears the cache of stored binaries |