summary refs log tree commit diff
path: root/src/clblast_c.cpp
diff options
context:
space:
mode:
author sivagnanamn <sivagnanammurthy@gmail.com> 2018-03-03 03:00:17 +0900
committer sivagnanamn <sivagnanammurthy@gmail.com> 2018-03-03 03:00:17 +0900
commit 1433dc67f17a94d1089291f4386d4fe668eb4a62 (patch)
tree a43331e8060c436bd7fc4af4ff42e16b96c203a3 /src/clblast_c.cpp
parent 1940e670094822f2d01db7390c210eb6ff949457 (diff)
Added C API for getting GEMM temp buffer size
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r-- src/clblast_c.cpp 239
1 files changed, 238 insertions, 1 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index f9592f14..b91ad308 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -4072,6 +4072,243 @@ CLBlastStatusCode CLBlastHgemmStridedBatched(const CLBlastLayout layout, const C
// =================================================================================================
+// GEMM with temporary buffer (optional, for advanced users)
+// Single-precision GEMM taking a caller-provided temporary buffer. Converts the C layout and
+// transpose enums to their clblast:: counterparts and forwards every argument, including
+// 'temp_buffer', to clblast::Gemm. Any exception is mapped to a status code through
+// clblast::DispatchExceptionForC.
+CLBlastStatusCode CLBlastSgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const float alpha,
+                                             const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+                                             const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+                                             const float beta,
+                                             cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue, cl_event* event,
+                                             cl_mem temp_buffer) {
+  try {
+    // Translate the C enumerations into the C++ API's scoped types
+    const auto layout_cpp = static_cast<clblast::Layout>(layout);
+    const auto a_transpose_cpp = static_cast<clblast::Transpose>(a_transpose);
+    const auto b_transpose_cpp = static_cast<clblast::Transpose>(b_transpose);
+    const auto status = clblast::Gemm(layout_cpp, a_transpose_cpp, b_transpose_cpp,
+                                      m, n, k,
+                                      alpha,
+                                      a_buffer, a_offset, a_ld,
+                                      b_buffer, b_offset, b_ld,
+                                      beta,
+                                      c_buffer, c_offset, c_ld,
+                                      queue, event, temp_buffer);
+    return static_cast<CLBlastStatusCode>(status);
+  }
+  catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+// Double-precision GEMM taking a caller-provided temporary buffer. Converts the C layout and
+// transpose enums to their clblast:: counterparts and forwards every argument, including
+// 'temp_buffer', to clblast::Gemm. Any exception is mapped to a status code through
+// clblast::DispatchExceptionForC.
+CLBlastStatusCode CLBlastDgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const double alpha,
+                                             const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+                                             const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+                                             const double beta,
+                                             cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue, cl_event* event,
+                                             cl_mem temp_buffer) {
+  try {
+    // Translate the C enumerations into the C++ API's scoped types
+    const auto layout_cpp = static_cast<clblast::Layout>(layout);
+    const auto a_transpose_cpp = static_cast<clblast::Transpose>(a_transpose);
+    const auto b_transpose_cpp = static_cast<clblast::Transpose>(b_transpose);
+    const auto status = clblast::Gemm(layout_cpp, a_transpose_cpp, b_transpose_cpp,
+                                      m, n, k,
+                                      alpha,
+                                      a_buffer, a_offset, a_ld,
+                                      b_buffer, b_offset, b_ld,
+                                      beta,
+                                      c_buffer, c_offset, c_ld,
+                                      queue, event, temp_buffer);
+    return static_cast<CLBlastStatusCode>(status);
+  }
+  catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+// Single-precision complex GEMM taking a caller-provided temporary buffer. The cl_float2 scalars
+// are repacked into float2 values from their two components (s[0], s[1]), the C enums are
+// converted to their clblast:: counterparts, and everything (including 'temp_buffer') is forwarded
+// to clblast::Gemm. Any exception is mapped to a status code through
+// clblast::DispatchExceptionForC.
+CLBlastStatusCode CLBlastCgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const cl_float2 alpha,
+                                             const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+                                             const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+                                             const cl_float2 beta,
+                                             cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue, cl_event* event,
+                                             cl_mem temp_buffer) {
+  try {
+    // Translate the C enumerations into the C++ API's scoped types
+    const auto layout_cpp = static_cast<clblast::Layout>(layout);
+    const auto a_transpose_cpp = static_cast<clblast::Transpose>(a_transpose);
+    const auto b_transpose_cpp = static_cast<clblast::Transpose>(b_transpose);
+    // Rebuild the scalars in the type expected by the C++ routine
+    const auto alpha_cpp = float2{alpha.s[0], alpha.s[1]};
+    const auto beta_cpp = float2{beta.s[0], beta.s[1]};
+    const auto status = clblast::Gemm(layout_cpp, a_transpose_cpp, b_transpose_cpp,
+                                      m, n, k,
+                                      alpha_cpp,
+                                      a_buffer, a_offset, a_ld,
+                                      b_buffer, b_offset, b_ld,
+                                      beta_cpp,
+                                      c_buffer, c_offset, c_ld,
+                                      queue, event, temp_buffer);
+    return static_cast<CLBlastStatusCode>(status);
+  }
+  catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+// Double-precision complex GEMM taking a caller-provided temporary buffer. The cl_double2 scalars
+// are repacked into double2 values from their two components (s[0], s[1]), the C enums are
+// converted to their clblast:: counterparts, and everything (including 'temp_buffer') is forwarded
+// to clblast::Gemm. Any exception is mapped to a status code through
+// clblast::DispatchExceptionForC.
+CLBlastStatusCode CLBlastZgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const cl_double2 alpha,
+                                             const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+                                             const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+                                             const cl_double2 beta,
+                                             cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue, cl_event* event,
+                                             cl_mem temp_buffer) {
+  try {
+    // Translate the C enumerations into the C++ API's scoped types
+    const auto layout_cpp = static_cast<clblast::Layout>(layout);
+    const auto a_transpose_cpp = static_cast<clblast::Transpose>(a_transpose);
+    const auto b_transpose_cpp = static_cast<clblast::Transpose>(b_transpose);
+    // Rebuild the scalars in the type expected by the C++ routine
+    const auto alpha_cpp = double2{alpha.s[0], alpha.s[1]};
+    const auto beta_cpp = double2{beta.s[0], beta.s[1]};
+    const auto status = clblast::Gemm(layout_cpp, a_transpose_cpp, b_transpose_cpp,
+                                      m, n, k,
+                                      alpha_cpp,
+                                      a_buffer, a_offset, a_ld,
+                                      b_buffer, b_offset, b_ld,
+                                      beta_cpp,
+                                      c_buffer, c_offset, c_ld,
+                                      queue, event, temp_buffer);
+    return static_cast<CLBlastStatusCode>(status);
+  }
+  catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+// Half-precision GEMM taking a caller-provided temporary buffer. The cl_half scalars are passed
+// through unchanged, the C enums are converted to their clblast:: counterparts, and everything
+// (including 'temp_buffer') is forwarded to clblast::Gemm. Any exception is mapped to a status
+// code through clblast::DispatchExceptionForC.
+CLBlastStatusCode CLBlastHgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const cl_half alpha,
+                                             const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+                                             const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+                                             const cl_half beta,
+                                             cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue, cl_event* event,
+                                             cl_mem temp_buffer) {
+  try {
+    // Translate the C enumerations into the C++ API's scoped types
+    const auto layout_cpp = static_cast<clblast::Layout>(layout);
+    const auto a_transpose_cpp = static_cast<clblast::Transpose>(a_transpose);
+    const auto b_transpose_cpp = static_cast<clblast::Transpose>(b_transpose);
+    const auto status = clblast::Gemm(layout_cpp, a_transpose_cpp, b_transpose_cpp,
+                                      m, n, k,
+                                      alpha,
+                                      a_buffer, a_offset, a_ld,
+                                      b_buffer, b_offset, b_ld,
+                                      beta,
+                                      c_buffer, c_offset, c_ld,
+                                      queue, event, temp_buffer);
+    return static_cast<CLBlastStatusCode>(status);
+  }
+  catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+
+// =================================================================================================
+
+// Retrieves the required size of the temporary buffer for GEMM
+// Queries the temporary-buffer size for single-precision GEMM with the given problem description.
+// Converts the C enums to their clblast:: counterparts and calls clblast::GemmTempBufferSize<float>,
+// handing it '*temp_buffer_size' (presumably taken by reference and filled in by the C++ routine).
+// Returns CLBlastInvalidValue when 'temp_buffer_size' is null; otherwise returns the status of the
+// C++ call, or the code produced by clblast::DispatchExceptionForC if an exception is thrown.
+CLBlastStatusCode CLBlastSGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const size_t a_offset, const size_t a_ld,
+                                             const size_t b_offset, const size_t b_ld,
+                                             const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue,
+                                             size_t* temp_buffer_size) {
+
+  // The output pointer is dereferenced below; a null dereference cannot be intercepted by the
+  // catch-all handler, so reject it up-front with a regular status code.
+  if (temp_buffer_size == nullptr) { return CLBlastInvalidValue; }
+
+  try {
+    return static_cast<CLBlastStatusCode>(
+      clblast::GemmTempBufferSize<float>(static_cast<clblast::Layout>(layout),
+                                         static_cast<clblast::Transpose>(a_transpose),
+                                         static_cast<clblast::Transpose>(b_transpose),
+                                         m, n, k,
+                                         a_offset, a_ld,
+                                         b_offset, b_ld,
+                                         c_offset, c_ld,
+                                         queue, *temp_buffer_size)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// Queries the temporary-buffer size for double-precision GEMM with the given problem description.
+// Converts the C enums to their clblast:: counterparts and calls clblast::GemmTempBufferSize<double>,
+// handing it '*temp_buffer_size' (presumably taken by reference and filled in by the C++ routine).
+// Returns CLBlastInvalidValue when 'temp_buffer_size' is null; otherwise returns the status of the
+// C++ call, or the code produced by clblast::DispatchExceptionForC if an exception is thrown.
+CLBlastStatusCode CLBlastDGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const size_t a_offset, const size_t a_ld,
+                                             const size_t b_offset, const size_t b_ld,
+                                             const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue,
+                                             size_t* temp_buffer_size) {
+
+  // The output pointer is dereferenced below; a null dereference cannot be intercepted by the
+  // catch-all handler, so reject it up-front with a regular status code.
+  if (temp_buffer_size == nullptr) { return CLBlastInvalidValue; }
+
+  try {
+    return static_cast<CLBlastStatusCode>(
+      clblast::GemmTempBufferSize<double>(static_cast<clblast::Layout>(layout),
+                                          static_cast<clblast::Transpose>(a_transpose),
+                                          static_cast<clblast::Transpose>(b_transpose),
+                                          m, n, k,
+                                          a_offset, a_ld,
+                                          b_offset, b_ld,
+                                          c_offset, c_ld,
+                                          queue, *temp_buffer_size)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// Queries the temporary-buffer size for single-precision complex GEMM with the given problem
+// description. Converts the C enums to their clblast:: counterparts and calls
+// clblast::GemmTempBufferSize<float2>, handing it '*temp_buffer_size' (presumably taken by
+// reference and filled in by the C++ routine).
+// Returns CLBlastInvalidValue when 'temp_buffer_size' is null; otherwise returns the status of the
+// C++ call, or the code produced by clblast::DispatchExceptionForC if an exception is thrown.
+CLBlastStatusCode CLBlastCGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const size_t a_offset, const size_t a_ld,
+                                             const size_t b_offset, const size_t b_ld,
+                                             const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue,
+                                             size_t* temp_buffer_size) {
+
+  // The output pointer is dereferenced below; a null dereference cannot be intercepted by the
+  // catch-all handler, so reject it up-front with a regular status code.
+  if (temp_buffer_size == nullptr) { return CLBlastInvalidValue; }
+
+  try {
+    return static_cast<CLBlastStatusCode>(
+      clblast::GemmTempBufferSize<float2>(static_cast<clblast::Layout>(layout),
+                                          static_cast<clblast::Transpose>(a_transpose),
+                                          static_cast<clblast::Transpose>(b_transpose),
+                                          m, n, k,
+                                          a_offset, a_ld,
+                                          b_offset, b_ld,
+                                          c_offset, c_ld,
+                                          queue, *temp_buffer_size)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// Queries the temporary-buffer size for double-precision complex GEMM with the given problem
+// description. Converts the C enums to their clblast:: counterparts and calls
+// clblast::GemmTempBufferSize<double2>, handing it '*temp_buffer_size' (presumably taken by
+// reference and filled in by the C++ routine).
+// Returns CLBlastInvalidValue when 'temp_buffer_size' is null; otherwise returns the status of the
+// C++ call, or the code produced by clblast::DispatchExceptionForC if an exception is thrown.
+CLBlastStatusCode CLBlastZGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const size_t a_offset, const size_t a_ld,
+                                             const size_t b_offset, const size_t b_ld,
+                                             const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue,
+                                             size_t* temp_buffer_size) {
+
+  // The output pointer is dereferenced below; a null dereference cannot be intercepted by the
+  // catch-all handler, so reject it up-front with a regular status code.
+  if (temp_buffer_size == nullptr) { return CLBlastInvalidValue; }
+
+  try {
+    return static_cast<CLBlastStatusCode>(
+      clblast::GemmTempBufferSize<double2>(static_cast<clblast::Layout>(layout),
+                                           static_cast<clblast::Transpose>(a_transpose),
+                                           static_cast<clblast::Transpose>(b_transpose),
+                                           m, n, k,
+                                           a_offset, a_ld,
+                                           b_offset, b_ld,
+                                           c_offset, c_ld,
+                                           queue, *temp_buffer_size)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// Queries the temporary-buffer size for half-precision GEMM with the given problem description.
+// Converts the C enums to their clblast:: counterparts and calls clblast::GemmTempBufferSize<half>,
+// handing it '*temp_buffer_size' (presumably taken by reference and filled in by the C++ routine).
+// Returns CLBlastInvalidValue when 'temp_buffer_size' is null; otherwise returns the status of the
+// C++ call, or the code produced by clblast::DispatchExceptionForC if an exception is thrown.
+CLBlastStatusCode CLBlastHGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+                                             const size_t m, const size_t n, const size_t k,
+                                             const size_t a_offset, const size_t a_ld,
+                                             const size_t b_offset, const size_t b_ld,
+                                             const size_t c_offset, const size_t c_ld,
+                                             cl_command_queue* queue,
+                                             size_t* temp_buffer_size) {
+
+  // The output pointer is dereferenced below; a null dereference cannot be intercepted by the
+  // catch-all handler, so reject it up-front with a regular status code.
+  if (temp_buffer_size == nullptr) { return CLBlastInvalidValue; }
+
+  try {
+    return static_cast<CLBlastStatusCode>(
+      clblast::GemmTempBufferSize<half>(static_cast<clblast::Layout>(layout),
+                                        static_cast<clblast::Transpose>(a_transpose),
+                                        static_cast<clblast::Transpose>(b_transpose),
+                                        m, n, k,
+                                        a_offset, a_ld,
+                                        b_offset, b_ld,
+                                        c_offset, c_ld,
+                                        queue, *temp_buffer_size)
+    );
+  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// =================================================================================================
+
// Clears the cache of stored binaries
CLBlastStatusCode CLBlastClearCache() {
try {
@@ -4106,4 +4343,4 @@ CLBlastStatusCode PUBLIC_API CLBlastOverrideParameters(const cl_device_id device
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-// =================================================================================================
+// =================================================================================================
\ No newline at end of file