summary | refs | log | tree | commit | diff
path: root/src/clblast.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- src/clblast.cpp 71
1 files changed, 64 insertions, 7 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 7d2c2cef..f5e2f1be 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -1651,17 +1651,21 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
- cl_command_queue* queue, cl_event* event) {
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer) {
try {
auto queue_cpp = Queue(*queue);
auto routine = Xgemm<T>(queue_cpp, event);
+ const auto temp_buffer_provided = temp_buffer != nullptr;
+ auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);
routine.DoGemm(layout, a_transpose, b_transpose,
m, n, k,
alpha,
Buffer<T>(a_buffer), a_offset, a_ld,
Buffer<T>(b_buffer), b_offset, b_ld,
beta,
- Buffer<T>(c_buffer), c_offset, c_ld);
+ Buffer<T>(c_buffer), c_offset, c_ld,
+ temp_buffer_cpp, temp_buffer_provided);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
@@ -1672,7 +1676,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const float,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double,
@@ -1680,7 +1684,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const double,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2,
@@ -1688,7 +1692,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const float2,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2,
@@ -1696,7 +1700,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
const cl_mem, const size_t, const size_t,
const double2,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half,
@@ -1704,7 +1708,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
const cl_mem, const size_t, const size_t,
const half,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
@@ -2333,4 +2337,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
cl_command_queue*, cl_event*);
// =================================================================================================
+
+// Retrieves the required size (in bytes) of the temporary buffer for the GEMM kernel (optional); the result is written to 'temp_buffer_size'
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t& temp_buffer_size) { // 'queue' is dereferenced below; 'temp_buffer_size' is the output argument
+ try {
+
+ // Retrieves the tuning database for the device that belongs to the given queue
+ const auto queue_cpp = Queue(*queue);
+ const auto device = queue_cpp.GetDevice();
+ const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"}; // database entries looked up by name below
+ Databases db(kernel_names);
+ Routine::InitDatabase(device, kernel_names, PrecisionValue<T>(), {}, db); // fills 'db' for this device and the precision of T
+
+ // Computes the buffer size (as a number of elements of type T at this point)
+ if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) { // when this returns true a size of zero is reported
+ temp_buffer_size = 0;
+ }
+ else {
+ temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
+ a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
+ db["MWG"], db["NWG"], db["KWG"]); // values taken from the tuning database
+ }
+ temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); } // any exception is routed through DispatchException(), whose result is returned
+}
+template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose, // explicit instantiation for float
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose, // explicit instantiation for double
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose, // explicit instantiation for float2
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose, // explicit instantiation for double2
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose, // explicit instantiation for half
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+
+// =================================================================================================
} // namespace clblast