diff options
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r-- | src/clblast_c.cpp | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index 6018bcfa..83450e6f 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -3447,6 +3447,113 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } +// AXPY +CLBlastStatusCode CLBlastSaxpyBatched(const size_t n, + const float *alphas, + const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc, + cl_mem *y_buffers, const size_t y_offset, const size_t y_inc, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<float>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(alphas[batch]); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::AxpyBatched(n, + alphas_cpp.data(), + x_buffers, x_offset, x_inc, + y_buffers, y_offset, y_inc, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastDaxpyBatched(const size_t n, + const double *alphas, + const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc, + cl_mem *y_buffers, const size_t y_offset, const size_t y_inc, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<double>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(alphas[batch]); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::AxpyBatched(n, + alphas_cpp.data(), + x_buffers, x_offset, x_inc, + y_buffers, y_offset, y_inc, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastCaxpyBatched(const size_t n, + const cl_float2 *alphas, + const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc, + cl_mem *y_buffers, const size_t y_offset, const size_t y_inc, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<float2>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(float2{alphas[batch].s[0], alphas[batch].s[1]}); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::AxpyBatched(n, + alphas_cpp.data(), + x_buffers, x_offset, x_inc, + y_buffers, y_offset, y_inc, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastZaxpyBatched(const size_t n, + const cl_double2 *alphas, + const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc, + cl_mem *y_buffers, const size_t y_offset, const size_t y_inc, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<double2>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(double2{alphas[batch].s[0], alphas[batch].s[1]}); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::AxpyBatched(n, + alphas_cpp.data(), + x_buffers, x_offset, x_inc, + y_buffers, y_offset, y_inc, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastHaxpyBatched(const size_t n, + const cl_half *alphas, + const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc, + cl_mem *y_buffers, const size_t y_offset, const size_t y_inc, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<half>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(alphas[batch]); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::AxpyBatched(n, + alphas_cpp.data(), + x_buffers, x_offset, x_inc, + y_buffers, y_offset, y_inc, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + // ================================================================================================= // Clears the cache of stored binaries |