diff options
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r-- | src/clblast_c.cpp | 239 |
1 files changed, 238 insertions, 1 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index f9592f14..b91ad308 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -4072,6 +4072,243 @@ CLBlastStatusCode CLBlastHgemmStridedBatched(const CLBlastLayout layout, const C // ================================================================================================= +// GEMM with temporary buffer (optional, for advanced users) +CLBlastStatusCode CLBlastSgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const float alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const float beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event, + cl_mem temp_buffer) { + try { + return static_cast<CLBlastStatusCode>( + clblast::Gemm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event, temp_buffer) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastDgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const double alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const double beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event, + cl_mem temp_buffer) { + try { + return static_cast<CLBlastStatusCode>( + clblast::Gemm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event, temp_buffer) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastCgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_float2 alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_float2 beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event, + cl_mem temp_buffer) { + try { + return static_cast<CLBlastStatusCode>( + clblast::Gemm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + float2{alpha.s[0], alpha.s[1]}, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + float2{beta.s[0], beta.s[1]}, + c_buffer, c_offset, c_ld, + queue, event, temp_buffer) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastZgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_double2 alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_double2 beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event, + cl_mem temp_buffer) { + try { + return static_cast<CLBlastStatusCode>( + clblast::Gemm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + double2{alpha.s[0], alpha.s[1]}, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + double2{beta.s[0], beta.s[1]}, + c_buffer, c_offset, c_ld, + queue, event, temp_buffer) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastHgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event, + cl_mem temp_buffer) { + try { + return static_cast<CLBlastStatusCode>( + clblast::Gemm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event, temp_buffer) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + +// ================================================================================================= + +// GEMM get temporary buffer size +CLBlastStatusCode CLBlastSGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, + size_t* temp_buffer_size){ + + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmTempBufferSize<float>(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + a_offset, a_ld, + b_offset, b_ld, + c_offset, c_ld, + queue, *temp_buffer_size) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + +CLBlastStatusCode CLBlastDGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, + size_t* temp_buffer_size){ + + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmTempBufferSize<double>(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + a_offset, a_ld, + b_offset, b_ld, + c_offset, c_ld, + queue, *temp_buffer_size) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + +CLBlastStatusCode CLBlastCGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, + size_t* temp_buffer_size){ + + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmTempBufferSize<float2>(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + a_offset, a_ld, + b_offset, b_ld, + c_offset, c_ld, + queue, *temp_buffer_size) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + +CLBlastStatusCode CLBlastZGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, + size_t* temp_buffer_size){ + + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmTempBufferSize<double2>(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + a_offset, a_ld, + b_offset, b_ld, + c_offset, c_ld, + queue, *temp_buffer_size) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + +CLBlastStatusCode CLBlastHGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, + size_t* temp_buffer_size){ + + try { + return static_cast<CLBlastStatusCode>( + clblast::GemmTempBufferSize<half>(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + a_offset, a_ld, + b_offset, b_ld, + c_offset, c_ld, + queue, *temp_buffer_size) + ); + } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + +// ================================================================================================= + // Clears the cache of stored binaries CLBlastStatusCode CLBlastClearCache() { try { @@ -4106,4 +4343,4 @@ CLBlastStatusCode PUBLIC_API CLBlastOverrideParameters(const cl_device_id device } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } -// ================================================================================================= +// =================================================================================================
\ No newline at end of file |