diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-10 21:24:35 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-10 21:24:35 +0100 |
commit | 49e04c7fce8fed45559e143137cef3a1a36328cc (patch) | |
tree | f73a5c280f12cc5e38f6d4fd4e853b8b8e1aa432 /src/clblast_c.cpp | |
parent | de3500ed18ddb39261ffa270f460909571276462 (diff) |
Added API and test infrastructure for the batched GEMM routine
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r-- | src/clblast_c.cpp | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index b09f8c54..f5104bad 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -3554,6 +3554,163 @@ CLBlastStatusCode CLBlastHaxpyBatched(const size_t n, } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } +// GEMM +CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const float *alphas, + const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, + const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, + const float *betas, + cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<float>(); + auto betas_cpp = std::vector<float>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(alphas[batch]); + betas_cpp.push_back(betas[batch]); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alphas_cpp.data(), + a_buffer, a_offsets, a_ld, + b_buffer, b_offsets, b_ld, + betas_cpp.data(), + c_buffer, c_offsets, c_ld, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const double *alphas, + const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, + const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, + const double *betas, + cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<double>(); + auto betas_cpp = std::vector<double>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(alphas[batch]); + betas_cpp.push_back(betas[batch]); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alphas_cpp.data(), + a_buffer, a_offsets, a_ld, + b_buffer, b_offsets, b_ld, + betas_cpp.data(), + c_buffer, c_offsets, c_ld, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_float2 *alphas, + const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, + const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, + const cl_float2 *betas, + cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<float2>(); + auto betas_cpp = std::vector<float2>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(float2{alphas[batch].s[0], alphas[batch].s[1]}); + betas_cpp.push_back(float2{betas[batch].s[0], betas[batch].s[1]}); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alphas_cpp.data(), + a_buffer, a_offsets, a_ld, + b_buffer, b_offsets, b_ld, + betas_cpp.data(), + c_buffer, c_offsets, c_ld, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_double2 *alphas, + const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, + const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, + const cl_double2 *betas, + cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<double2>(); + auto betas_cpp = std::vector<double2>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(double2{alphas[batch].s[0], alphas[batch].s[1]}); + betas_cpp.push_back(double2{betas[batch].s[0], betas[batch].s[1]}); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alphas_cpp.data(), + a_buffer, a_offsets, a_ld, + b_buffer, b_offsets, b_ld, + betas_cpp.data(), + c_buffer, c_offsets, c_ld, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_half *alphas, + const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, + const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, + const cl_half *betas, + cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + auto alphas_cpp = std::vector<half>(); + auto betas_cpp = std::vector<half>(); + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + alphas_cpp.push_back(alphas[batch]); + betas_cpp.push_back(betas[batch]); + } + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alphas_cpp.data(), + a_buffer, a_offsets, a_ld, + b_buffer, b_offsets, b_ld, + betas_cpp.data(), + c_buffer, c_offsets, c_ld, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + // ================================================================================================= // Clears the cache of stored binaries |