summaryrefslogtreecommitdiff
path: root/src/clblast_c.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r--src/clblast_c.cpp107
1 files changed, 107 insertions, 0 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 6018bcfa..83450e6f 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -3447,6 +3447,113 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
+// AXPY
+CLBlastStatusCode CLBlastSaxpyBatched(const size_t n,
+ const float *alphas,
+ const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+ cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<float>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::AxpyBatched(n,
+ alphas_cpp.data(),
+ x_buffers, x_offset, x_inc,
+ y_buffers, y_offset, y_inc,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastDaxpyBatched(const size_t n,
+ const double *alphas,
+ const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+ cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<double>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::AxpyBatched(n,
+ alphas_cpp.data(),
+ x_buffers, x_offset, x_inc,
+ y_buffers, y_offset, y_inc,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastCaxpyBatched(const size_t n,
+ const cl_float2 *alphas,
+ const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+ cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<float2>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(float2{alphas[batch].s[0], alphas[batch].s[1]});
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::AxpyBatched(n,
+ alphas_cpp.data(),
+ x_buffers, x_offset, x_inc,
+ y_buffers, y_offset, y_inc,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastZaxpyBatched(const size_t n,
+ const cl_double2 *alphas,
+ const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+ cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<double2>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(double2{alphas[batch].s[0], alphas[batch].s[1]});
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::AxpyBatched(n,
+ alphas_cpp.data(),
+ x_buffers, x_offset, x_inc,
+ y_buffers, y_offset, y_inc,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
+ const cl_half *alphas,
+ const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
+ cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<half>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::AxpyBatched(n,
+ alphas_cpp.data(),
+ x_buffers, x_offset, x_inc,
+ y_buffers, y_offset, y_inc,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
// =================================================================================================
// Clears the cache of stored binaries