summaryrefslogtreecommitdiff
path: root/src/clblast_c.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r--src/clblast_c.cpp157
1 files changed, 157 insertions, 0 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index b09f8c54..f5104bad 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -3554,6 +3554,163 @@ CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
+// GEMM
+CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const float *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const float *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<float>();
+ auto betas_cpp = std::vector<float>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ betas_cpp.push_back(betas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const double *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const double *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<double>();
+ auto betas_cpp = std::vector<double>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ betas_cpp.push_back(betas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_float2 *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_float2 *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<float2>();
+ auto betas_cpp = std::vector<float2>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(float2{alphas[batch].s[0], alphas[batch].s[1]});
+ betas_cpp.push_back(float2{betas[batch].s[0], betas[batch].s[1]});
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_double2 *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_double2 *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<double2>();
+ auto betas_cpp = std::vector<double2>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(double2{alphas[batch].s[0], alphas[batch].s[1]});
+ betas_cpp.push_back(double2{betas[batch].s[0], betas[batch].s[1]});
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_half *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<half>();
+ auto betas_cpp = std::vector<half>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ betas_cpp.push_back(betas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
// =================================================================================================
// Clears the cache of stored binaries