diff options
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- | src/clblast.cpp | 71 |
1 files changed, 64 insertions, 7 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 7d2c2cef..f5e2f1be 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -1651,17 +1651,21 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
                 const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
                 const T beta,
                 cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
-                cl_command_queue* queue, cl_event* event) {
+                cl_command_queue* queue, cl_event* event,
+                cl_mem temp_buffer) {
   try {
     auto queue_cpp = Queue(*queue);
     auto routine = Xgemm<T>(queue_cpp, event);
+    const auto temp_buffer_provided = temp_buffer != nullptr;
+    auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);
     routine.DoGemm(layout, a_transpose, b_transpose,
                    m, n, k,
                    alpha,
                    Buffer<T>(a_buffer), a_offset, a_ld,
                    Buffer<T>(b_buffer), b_offset, b_ld,
                    beta,
-                   Buffer<T>(c_buffer), c_offset, c_ld);
+                   Buffer<T>(c_buffer), c_offset, c_ld,
+                   temp_buffer_cpp, temp_buffer_provided);
     return StatusCode::kSuccess;
   } catch (...) { return DispatchException(); }
 }
@@ -1672,7 +1676,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const
                                            const cl_mem, const size_t, const size_t,
                                            const float,
                                            cl_mem, const size_t, const size_t,
-                                           cl_command_queue*, cl_event*);
+                                           cl_command_queue*, cl_event*, cl_mem);
 template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
                                             const size_t, const size_t, const size_t,
                                             const double,
@@ -1680,7 +1684,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
                                             const cl_mem, const size_t, const size_t,
                                             const double,
                                             cl_mem, const size_t, const size_t,
-                                            cl_command_queue*, cl_event*);
+                                            cl_command_queue*, cl_event*, cl_mem);
 template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
                                             const size_t, const size_t, const size_t,
                                             const float2,
@@ -1688,7 +1692,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
                                             const cl_mem, const size_t, const size_t,
                                             const float2,
                                             cl_mem, const size_t, const size_t,
-                                            cl_command_queue*, cl_event*);
+                                            cl_command_queue*, cl_event*, cl_mem);
 template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
                                              const size_t, const size_t, const size_t,
                                              const double2,
@@ -1696,7 +1700,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
                                              const cl_mem, const size_t, const size_t,
                                              const double2,
                                              cl_mem, const size_t, const size_t,
-                                             cl_command_queue*, cl_event*);
+                                             cl_command_queue*, cl_event*, cl_mem);
 template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
                                           const size_t, const size_t, const size_t,
                                           const half,
@@ -1704,7 +1708,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
                                           const cl_mem, const size_t, const size_t,
                                           const half,
                                           cl_mem, const size_t, const size_t,
-                                          cl_command_queue*, cl_event*);
+                                          cl_command_queue*, cl_event*, cl_mem);
 
 // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
 template <typename T>
@@ -2333,4 +2337,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
                                                  cl_command_queue*, cl_event*);
 
 // =================================================================================================
+
+// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+                              const size_t m, const size_t n, const size_t k,
+                              const size_t a_offset, const size_t a_ld,
+                              const size_t b_offset, const size_t b_ld,
+                              const size_t c_offset, const size_t c_ld,
+                              cl_command_queue* queue, size_t& temp_buffer_size) {
+  try {
+
+    // Retrieves the tuning database
+    const auto queue_cpp = Queue(*queue);
+    const auto device = queue_cpp.GetDevice();
+    const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"};
+    Databases db(kernel_names);
+    Routine::InitDatabase(device, kernel_names, PrecisionValue<T>(), {}, db);
+
+    // Computes the buffer size
+    if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) {
+      temp_buffer_size = 0;
+    }
+    else {
+      temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
+                                               a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
+                                               db["MWG"], db["NWG"], db["KWG"]);
+    }
+    temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
+    return StatusCode::kSuccess;
+  } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
+                                                         const size_t, const size_t, const size_t,
+                                                         const size_t, const size_t, const size_t, const size_t,
+                                                         const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
+                                                          const size_t, const size_t, const size_t,
+                                                          const size_t, const size_t, const size_t, const size_t,
+                                                          const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
+                                                          const size_t, const size_t, const size_t,
+                                                          const size_t, const size_t, const size_t, const size_t,
+                                                          const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
+                                                           const size_t, const size_t, const size_t,
+                                                           const size_t, const size_t, const size_t, const size_t,
+                                                           const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
+                                                        const size_t, const size_t, const size_t,
+                                                        const size_t, const size_t, const size_t, const size_t,
+                                                        const size_t, const size_t, cl_command_queue*, size_t&);
+
+// =================================================================================================
 } // namespace clblast