diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/clblast.cpp | 59 | ||||
-rw-r--r-- | src/clblast_c.cpp | 107 |
2 files changed, 166 insertions, 0 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index a63d766c..55562419 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -71,6 +71,7 @@
 
 // Level-x includes (non-BLAS)
 #include "routines/levelx/xomatcopy.hpp"
+#include "routines/levelx/xaxpybatched.hpp"
 
 namespace clblast {
 
@@ -2172,6 +2173,64 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
                                               const cl_mem, const size_t, const size_t,
                                               cl_mem, const size_t, const size_t,
                                               cl_command_queue*, cl_event*);
+
+// Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED
+template <typename T>
+StatusCode AxpyBatched(const size_t n,
+                       const T *alphas,
+                       const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+                       cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+                       const size_t batch_count,
+                       cl_command_queue* queue, cl_event* event) {
+  try {
+    auto queue_cpp = Queue(*queue);
+    auto routine = XaxpyBatched<T>(queue_cpp, event);
+    auto alphas_cpp = std::vector<T>();
+    auto x_buffers_cpp = std::vector<Buffer<T>>();
+    auto y_buffers_cpp = std::vector<Buffer<T>>();
+    for (auto batch = size_t{0}; batch < batch_count; ++batch) {  // reads batch_count entries of each array
+      alphas_cpp.push_back(alphas[batch]);
+      x_buffers_cpp.push_back(Buffer<T>(x_buffers[batch]));
+      y_buffers_cpp.push_back(Buffer<T>(y_buffers[batch]));
+    }
+    routine.DoAxpyBatched(n,
+                          alphas_cpp,
+                          x_buffers_cpp, x_offset, x_inc,
+                          y_buffers_cpp, y_offset, y_inc,
+                          batch_count);
+    return StatusCode::kSuccess;
+  } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API AxpyBatched<float>(const size_t,
+                                                  const float*,
+                                                  const cl_mem*, const size_t, const size_t,
+                                                  cl_mem*, const size_t, const size_t,
+                                                  const size_t,
+                                                  cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API AxpyBatched<double>(const size_t,
+                                                   const double*,
+                                                   const cl_mem*, const size_t, const size_t,
+                                                   cl_mem*, const size_t, const size_t,
+                                                   const size_t,
+                                                   cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API AxpyBatched<float2>(const size_t,
+                                                   const float2*,
+                                                   const cl_mem*, const size_t, const size_t,
+                                                   cl_mem*, const size_t, const size_t,
+                                                   const size_t,
+                                                   cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API AxpyBatched<double2>(const size_t,
+                                                    const double2*,
+                                                    const cl_mem*, const size_t, const size_t,
+                                                    cl_mem*, const size_t, const size_t,
+                                                    const size_t,
+                                                    cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API AxpyBatched<half>(const size_t,
+                                                 const half*,
+                                                 const cl_mem*, const size_t, const size_t,
+                                                 cl_mem*, const size_t, const size_t,
+                                                 const size_t,
+                                                 cl_command_queue*, cl_event*);
 // =================================================================================================
 
 // Clears the cache of stored binaries
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 6018bcfa..83450e6f 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -3447,6 +3447,113 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran
   } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
 }
 
+// AXPYBATCHED: C wrappers; each copies 'alphas' into a std::vector and forwards to clblast::AxpyBatched
+CLBlastStatusCode CLBlastSaxpyBatched(const size_t n,
+                                      const float *alphas,
+                                      const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+                                      cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+                                      const size_t batch_count,
+                                      cl_command_queue* queue, cl_event* event) {
+  try {  // the alphas copy can throw (allocation): keep it inside the try so nothing leaves the C API
+    auto alphas_cpp = std::vector<float>();
+    for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+      alphas_cpp.push_back(alphas[batch]);
+    }
+    return static_cast<CLBlastStatusCode>(
+      clblast::AxpyBatched(n,
+                           alphas_cpp.data(),
+                           x_buffers, x_offset, x_inc,
+                           y_buffers, y_offset, y_inc,
+                           batch_count,
+                           queue, event)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastDaxpyBatched(const size_t n,
+                                      const double *alphas,
+                                      const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+                                      cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+                                      const size_t batch_count,
+                                      cl_command_queue* queue, cl_event* event) {
+  try {  // the alphas copy can throw (allocation): keep it inside the try so nothing leaves the C API
+    auto alphas_cpp = std::vector<double>();
+    for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+      alphas_cpp.push_back(alphas[batch]);
+    }
+    return static_cast<CLBlastStatusCode>(
+      clblast::AxpyBatched(n,
+                           alphas_cpp.data(),
+                           x_buffers, x_offset, x_inc,
+                           y_buffers, y_offset, y_inc,
+                           batch_count,
+                           queue, event)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastCaxpyBatched(const size_t n,
+                                      const cl_float2 *alphas,
+                                      const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+                                      cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+                                      const size_t batch_count,
+                                      cl_command_queue* queue, cl_event* event) {
+  try {  // the alphas copy can throw (allocation): keep it inside the try so nothing leaves the C API
+    auto alphas_cpp = std::vector<float2>();
+    for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+      alphas_cpp.push_back(float2{alphas[batch].s[0], alphas[batch].s[1]});
+    }
+    return static_cast<CLBlastStatusCode>(
+      clblast::AxpyBatched(n,
+                           alphas_cpp.data(),
+                           x_buffers, x_offset, x_inc,
+                           y_buffers, y_offset, y_inc,
+                           batch_count,
+                           queue, event)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastZaxpyBatched(const size_t n,
+                                      const cl_double2 *alphas,
+                                      const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+                                      cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+                                      const size_t batch_count,
+                                      cl_command_queue* queue, cl_event* event) {
+  try {  // the alphas copy can throw (allocation): keep it inside the try so nothing leaves the C API
+    auto alphas_cpp = std::vector<double2>();
+    for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+      alphas_cpp.push_back(double2{alphas[batch].s[0], alphas[batch].s[1]});
+    }
+    return static_cast<CLBlastStatusCode>(
+      clblast::AxpyBatched(n,
+                           alphas_cpp.data(),
+                           x_buffers, x_offset, x_inc,
+                           y_buffers, y_offset, y_inc,
+                           batch_count,
+                           queue, event)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
+                                      const cl_half *alphas,
+                                      const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+                                      cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+                                      const size_t batch_count,
+                                      cl_command_queue* queue, cl_event* event) {
+  try {  // the alphas copy can throw (allocation): keep it inside the try so nothing leaves the C API
+    auto alphas_cpp = std::vector<half>();
+    for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+      alphas_cpp.push_back(alphas[batch]);
+    }
+    return static_cast<CLBlastStatusCode>(
+      clblast::AxpyBatched(n,
+                           alphas_cpp.data(),
+                           x_buffers, x_offset, x_inc,
+                           y_buffers, y_offset, y_inc,
+                           batch_count,
+                           queue, event)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
 // =================================================================================================
 
 // Clears the cache of stored binaries