diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-07 14:27:15 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-07 14:27:15 +0100 |
commit | 9fb2c61b256ccf66b6a7b6f605008125288d60cf (patch) | |
tree | 2df0c0ed7a5be8e7f1b78131467e8620a2266da7 /src/clblast_c.cpp | |
parent | 0c48c6e6c4cd953523a10bcb804fde67e4650a57 (diff) |
Added API and tests for new GemmStridedBatched routine
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r-- | src/clblast_c.cpp | 127 |
1 files changed, 127 insertions, 0 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index 24697779..aa52cbca 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -3846,6 +3846,133 @@ CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastT } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } +// GEMM +CLBlastStatusCode CLBlastSgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const float alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, + const float beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alpha, + a_buffer, a_offset, a_ld, a_stride, + b_buffer, b_offset, b_ld, b_stride, + beta, + c_buffer, c_offset, c_ld, c_stride, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastDgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const double alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, + const double beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alpha, + a_buffer, a_offset, a_ld, a_stride, + b_buffer, b_offset, b_ld, b_stride, + beta, + c_buffer, c_offset, c_ld, c_stride, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastCgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_float2 alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, + const cl_float2 beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + float2{alpha.s[0], alpha.s[1]}, + a_buffer, a_offset, a_ld, a_stride, + b_buffer, b_offset, b_ld, b_stride, + float2{beta.s[0], beta.s[1]}, + c_buffer, c_offset, c_ld, c_stride, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastZgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_double2 alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, + const cl_double2 beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + double2{alpha.s[0], alpha.s[1]}, + a_buffer, a_offset, a_ld, a_stride, + b_buffer, b_offset, b_ld, b_stride, + double2{beta.s[0], beta.s[1]}, + c_buffer, c_offset, c_ld, c_stride, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastHgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, + const size_t batch_count, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alpha, + a_buffer, a_offset, a_ld, a_stride, + b_buffer, b_offset, b_ld, b_stride, + beta, + c_buffer, c_offset, c_ld, c_stride, + batch_count, + queue, event) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + // ================================================================================================= // Clears the cache of stored binaries |