From 6d1e30e61f5ef73f0a83e12f064cae64644034ca Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Thu, 28 Dec 2017 14:46:45 +0100 Subject: Added interface to compute the required temporary buffer size for GEMM --- src/clblast.cpp | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) (limited to 'src/clblast.cpp') diff --git a/src/clblast.cpp b/src/clblast.cpp index 7d2c2cef..e38a75ca 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2332,5 +2332,58 @@ template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const size_t, cl_command_queue*, cl_event*); +// ================================================================================================= + +// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) +template +StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + RawCommandQueue* queue, size_t& temp_buffer_size) { + try { + + // Retrieves the tuning database + const auto queue_cpp = Queue(*queue); + const auto device = queue_cpp.GetDevice(); + const auto kernel_names = std::vector{"Xgemm", "GemmRoutine"}; + Databases db(kernel_names); + Routine::InitDatabase(device, kernel_names, PrecisionValue(), {}, db); + + // Computes the buffer size + if (Xgemm::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) { + temp_buffer_size = 0; + } + else { + temp_buffer_size = Xgemm::GetTempSize(layout, a_transpose, b_transpose, m, n, k, + a_offset, a_ld, b_offset, b_ld, c_offset, c_ld, + db["MWG"], db["NWG"], db["KWG"]); + } + temp_buffer_size *= sizeof(T); // translate from num-elements to bytes + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); + // ================================================================================================= } // namespace clblast -- cgit v1.2.3