summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/api.md81
-rw-r--r--include/clblast_c.h86
-rwxr-xr-xscripts/generator/generator.py4
-rw-r--r--src/clblast_c.cpp239
4 files changed, 406 insertions, 4 deletions
diff --git a/doc/api.md b/doc/api.md
index 4a99cf9e..3e26688c 100644
--- a/doc/api.md
+++ b/doc/api.md
@@ -2236,6 +2236,50 @@ CLBlastStatusCode CLBlastHgemm(const CLBlastLayout layout, const CLBlastTranspos
cl_command_queue* queue, cl_event* event)
```
+C API with temporary buffer (user to allocate & pass `temp_buffer` with size provided by xGemmTempBufferSize() ):
+```
+CLBlastStatusCode CLBlastSgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const float alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const float beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer)
+CLBlastStatusCode CLBlastDgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const double alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const double beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer)
+CLBlastStatusCode CLBlastCgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_float2 alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_float2 beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer)
+CLBlastStatusCode CLBlastZgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_double2 alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_double2 beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer)
+CLBlastStatusCode CLBlastHgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer)
+```
+
Arguments to GEMM:
* `const Layout layout`: Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.
@@ -3355,8 +3399,43 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
cl_command_queue* queue, size_t& temp_buffer_size)
```
-A C API is not available for this function.
+C API:
+```
+CLBlastStatusCode CLBlastSGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t* temp_buffer_size)
+
+CLBlastStatusCode CLBlastDGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t* temp_buffer_size)
+
+CLBlastStatusCode CLBlastCGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t* temp_buffer_size)
+CLBlastStatusCode CLBlastZGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t* temp_buffer_size)
+
+CLBlastStatusCode CLBlastHGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t* temp_buffer_size)
+```
Arguments to GemmTempBufferSize:
* `const Layout layout`: Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.
diff --git a/include/clblast_c.h b/include/clblast_c.h
index a00aca45..051871ce 100644
--- a/include/clblast_c.h
+++ b/include/clblast_c.h
@@ -96,6 +96,7 @@ typedef enum CLBlastStatusCode_ {
CLBlastInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
+ CLBlastInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small
CLBlastInvalidBatchCount = -2049, // The batch count needs to be positive
CLBlastInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
CLBlastMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
@@ -1536,6 +1537,91 @@ CLBlastStatusCode PUBLIC_API CLBlastHgemmStridedBatched(const CLBlastLayout layo
cl_command_queue* queue, cl_event* event);
// =================================================================================================
+// General matrix-matrix multiplication with temporary buffer from user (optional, for advanced users): SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM
+CLBlastStatusCode PUBLIC_API CLBlastSgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const float alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const float beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer);
+CLBlastStatusCode PUBLIC_API CLBlastDgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const double alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const double beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer);
+CLBlastStatusCode PUBLIC_API CLBlastCgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_float2 alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_float2 beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer);
+CLBlastStatusCode PUBLIC_API CLBlastZgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_double2 alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_double2 beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer);
+CLBlastStatusCode PUBLIC_API CLBlastHgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event, cl_mem temp_buffer);
+
+// =================================================================================================
+// Retrieves the required size of the temporary buffer for the GEMM kernel: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM (optional)
+CLBlastStatusCode PUBLIC_API CLBlastSGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size);
+
+CLBlastStatusCode PUBLIC_API CLBlastDGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size);
+
+CLBlastStatusCode PUBLIC_API CLBlastCGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size);
+
+CLBlastStatusCode PUBLIC_API CLBlastZGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size);
+
+CLBlastStatusCode PUBLIC_API CLBlastHGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size);
+
+// =================================================================================================
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index d8208bd1..fddfd86e 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -49,8 +49,8 @@ FILES = [
"/src/clblast_cuda.cpp",
"/src/pyclblast/src/pyclblast.pyx"
]
-HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21, 288]
-FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55, 1]
+HEADER_LINES = [123, 21, 127, 24, 29, 41, 29, 65, 32, 95, 21, 288]
+FOOTER_LINES = [41, 56, 112, 275, 6, 6, 6, 9, 2, 41, 55, 1]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 123
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index f9592f14..b91ad308 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -4072,6 +4072,243 @@ CLBlastStatusCode CLBlastHgemmStridedBatched(const CLBlastLayout layout, const C
// =================================================================================================
+// GEMM with temporary buffer (optional, for advanced users)
+CLBlastStatusCode CLBlastSgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const float alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const float beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer) {
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::Gemm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ beta,
+ c_buffer, c_offset, c_ld,
+ queue, event, temp_buffer)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastDgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const double alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const double beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer) {
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::Gemm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ beta,
+ c_buffer, c_offset, c_ld,
+ queue, event, temp_buffer)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastCgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_float2 alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_float2 beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer) {
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::Gemm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ float2{alpha.s[0], alpha.s[1]},
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ float2{beta.s[0], beta.s[1]},
+ c_buffer, c_offset, c_ld,
+ queue, event, temp_buffer)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastZgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_double2 alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_double2 beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer) {
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::Gemm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ double2{alpha.s[0], alpha.s[1]},
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ double2{beta.s[0], beta.s[1]},
+ c_buffer, c_offset, c_ld,
+ queue, event, temp_buffer)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastHgemmWithTempBuffer(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer) {
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::Gemm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ beta,
+ c_buffer, c_offset, c_ld,
+ queue, event, temp_buffer)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// =================================================================================================
+
+// GEMM get temporary buffer size
+CLBlastStatusCode CLBlastSGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size){
+
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmTempBufferSize<float>(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ a_offset, a_ld,
+ b_offset, b_ld,
+ c_offset, c_ld,
+ queue, *temp_buffer_size)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+CLBlastStatusCode CLBlastDGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size){
+
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmTempBufferSize<double>(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ a_offset, a_ld,
+ b_offset, b_ld,
+ c_offset, c_ld,
+ queue, *temp_buffer_size)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+CLBlastStatusCode CLBlastCGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size){
+
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmTempBufferSize<float2>(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ a_offset, a_ld,
+ b_offset, b_ld,
+ c_offset, c_ld,
+ queue, *temp_buffer_size)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+CLBlastStatusCode CLBlastZGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size){
+
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmTempBufferSize<double2>(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ a_offset, a_ld,
+ b_offset, b_ld,
+ c_offset, c_ld,
+ queue, *temp_buffer_size)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+CLBlastStatusCode CLBlastHGemmTempBufferSize(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue,
+ size_t* temp_buffer_size){
+
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmTempBufferSize<half>(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ a_offset, a_ld,
+ b_offset, b_ld,
+ c_offset, c_ld,
+ queue, *temp_buffer_size)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// =================================================================================================
+
// Clears the cache of stored binaries
CLBlastStatusCode CLBlastClearCache() {
try {
@@ -4106,4 +4343,4 @@ CLBlastStatusCode PUBLIC_API CLBlastOverrideParameters(const cl_device_id device
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-// =================================================================================================
+// ================================================================================================= \ No newline at end of file