From ad1227c4f2934b0f60c0030101e18b8fb21daf8c Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sat, 30 Dec 2017 18:45:06 +0100 Subject: Added optional temp-buffer argument to C++ interface of GEMM --- include/clblast.h | 4 +++- src/clblast.cpp | 18 +++++++++++------- src/clpp11.hpp | 9 +++++---- src/cupp11.hpp | 6 +++--- src/routines/level3/xgemm.cpp | 24 +++++++++++++++++------- src/routines/level3/xgemm.hpp | 6 ++++-- 6 files changed, 43 insertions(+), 24 deletions(-) diff --git a/include/clblast.h b/include/clblast.h index 3318768a..a05b487f 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -97,6 +97,7 @@ enum class StatusCode { kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small // Custom additional status codes for CLBlast + kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small kInvalidBatchCount = -2049, // The batch count needs to be positive kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel @@ -520,7 +521,8 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const T beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, - cl_command_queue* queue, cl_event* event = nullptr); + cl_command_queue* queue, cl_event* event = nullptr, + cl_mem temp_buffer = nullptr); // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template diff --git a/src/clblast.cpp b/src/clblast.cpp index e38a75ca..06449840 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -1651,17 +1651,21 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const T beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, - cl_command_queue* queue, cl_event* event) { + cl_command_queue* queue, cl_event* event, + cl_mem temp_buffer) { // optional argument try { auto queue_cpp = Queue(*queue); auto routine = Xgemm(queue_cpp, event); + auto temp_buffer_provided = temp_buffer != nullptr; + auto temp_buffer_cpp = temp_buffer_provided ? Buffer(temp_buffer) : Buffer(nullptr); routine.DoGemm(layout, a_transpose, b_transpose, m, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld, beta, - Buffer(c_buffer), c_offset, c_ld); + Buffer(c_buffer), c_offset, c_ld, + temp_buffer_cpp, temp_buffer_provided); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } @@ -1672,7 +1676,7 @@ template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const const cl_mem, const size_t, const size_t, const float, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double, @@ -1680,7 +1684,7 @@ template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const const cl_mem, const size_t, const size_t, const double, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float2, @@ -1688,7 +1692,7 @@ template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const const cl_mem, const size_t, const size_t, const float2, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double2, @@ -1696,7 +1700,7 @@ template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, cons const cl_mem, const size_t, const size_t, const double2, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const half, @@ -1704,7 +1708,7 @@ template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const T const cl_mem, const size_t, const size_t, const half, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template diff --git a/src/clpp11.hpp b/src/clpp11.hpp index 6ebf1322..2119f26b 100644 --- a/src/clpp11.hpp +++ b/src/clpp11.hpp @@ -614,10 +614,11 @@ class Buffer { } // Regular constructor with memory management. If this class does not own the buffer object, then - // the memory will not be freed automatically afterwards. + // the memory will not be freed automatically afterwards. If the size is set to 0, this will + // become a stub containing a nullptr explicit Buffer(const Context &context, const BufferAccess access, const size_t size): - buffer_(new cl_mem, [access](cl_mem* m) { - if (access != BufferAccess::kNotOwned) { CheckError(clReleaseMemObject(*m)); } + buffer_(new cl_mem, [access, size](cl_mem* m) { + if (access != BufferAccess::kNotOwned && size > 0) { CheckError(clReleaseMemObject(*m)); } delete m; }), access_(access) { @@ -625,7 +626,7 @@ class Buffer { if (access_ == BufferAccess::kReadOnly) { flags = CL_MEM_READ_ONLY; } if (access_ == BufferAccess::kWriteOnly) { flags = CL_MEM_WRITE_ONLY; } auto status = CL_SUCCESS; - *buffer_ = clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status); + *buffer_ = (size > 0) ? clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status) : nullptr; CLCudaAPIError::Check(status, "clCreateBuffer"); } diff --git a/src/cupp11.hpp b/src/cupp11.hpp index eb177ca2..509ae3e8 100644 --- a/src/cupp11.hpp +++ b/src/cupp11.hpp @@ -549,12 +549,12 @@ public: // Regular constructor with memory management. If this class does not own the buffer object, then // the memory will not be freed automatically afterwards. explicit Buffer(const Context &, const BufferAccess access, const size_t size): - buffer_(new CUdeviceptr, [access](CUdeviceptr* m) { - if (access != BufferAccess::kNotOwned) { CheckError(cuMemFree(*m)); } + buffer_(new CUdeviceptr, [access, size](CUdeviceptr* m) { + if (access != BufferAccess::kNotOwned && size > 0) { CheckError(cuMemFree(*m)); } delete m; }), access_(access) { - CheckError(cuMemAlloc(buffer_.get(), size*sizeof(T))); + if (size > 0) { CheckError(cuMemAlloc(buffer_.get(), size*sizeof(T))); } } // As above, but now with read/write access as a default diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 1fe10462..4c1b9558 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -61,7 +61,8 @@ void Xgemm::DoGemm(const Layout layout, const Buffer &a_buffer, const size_t a_offset, const size_t a_ld, const Buffer &b_buffer, const size_t b_offset, const size_t b_ld, const T beta, - const Buffer &c_buffer, const size_t c_offset, const size_t c_ld) { + const Buffer &c_buffer, const size_t c_offset, const size_t c_ld, + const Buffer &temp_buffer, const bool temp_buffer_provided) { // optional arguments // Computes the transpose/conjugate options and sets the a/b/c sizes based on that bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; @@ -94,7 +95,8 @@ void Xgemm::DoGemm(const Layout layout, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, c_buffer, c_offset, c_ld, a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, - a_one, a_two, b_one, b_two, c_one, c_two); + a_one, a_two, b_one, b_two, c_one, c_two, + temp_buffer, temp_buffer_provided); } } @@ -114,7 +116,8 @@ void Xgemm::GemmIndirect(const size_t m, const size_t n, const size_t k, const bool a_conjugate, const bool b_conjugate, const size_t a_one, const size_t a_two, const size_t b_one, const size_t b_two, - const size_t c_one, const size_t c_two) { + const size_t c_one, const size_t c_two, + const Buffer &temp_buffer, const bool temp_buffer_provided) { // Calculates the ceiled versions of m, n, and k const auto m_ceiled = Ceil(m, db_["MWG"]); @@ -143,12 +146,19 @@ void Xgemm::GemmIndirect(const size_t m, const size_t n, const size_t k, // Creates the buffer for the (optional) temporary matrices. Note that we use 'a_buffer' in case // when no temporary buffer is needed, but that's just to make it compile: it is never used. - const auto temp_buffer = (temp_size > 0) ? Buffer(context_, temp_size) : a_buffer; + const auto temp_buffer_all = (temp_buffer_provided) ? temp_buffer : + ((temp_size > 0) ? Buffer(context_, temp_size) : a_buffer); + + // Verifies if the provided temporary buffer is large enough + if (temp_buffer_provided) { + const auto required_size = temp_size * sizeof(T); + if (temp_buffer_all.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryTemp); } + } // Sets the buffer pointers for (temp) matrices A, B, and C - const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer; - const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer; - const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer; + const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer_all; + const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer_all; + const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer_all; // Events of all kernels (including pre/post processing kernels) auto eventWaitList = std::vector(); diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index 25b1f5c9..b354de1b 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -158,7 +158,8 @@ class Xgemm: public Routine { const Buffer &a_buffer, const size_t a_offset, const size_t a_ld, const Buffer &b_buffer, const size_t b_offset, const size_t b_ld, const T beta, - const Buffer &c_buffer, const size_t c_offset, const size_t c_ld); + const Buffer &c_buffer, const size_t c_offset, const size_t c_ld, + const Buffer &temp_buffer = Buffer(nullptr), const bool temp_buffer_provided = false); // Indirect version of GEMM (with pre and post-processing kernels) void GemmIndirect(const size_t m, const size_t n, const size_t k, @@ -171,7 +172,8 @@ class Xgemm: public Routine { const bool a_conjugate, const bool b_conjugate, const size_t a_one, const size_t a_two, const size_t b_one, const size_t b_two, - const size_t c_one, const size_t c_two); + const size_t c_one, const size_t c_two, + const Buffer &temp_buffer, const bool temp_buffer_provided); // Direct version of GEMM (no pre and post-processing kernels) void GemmDirect(const size_t m, const size_t n, const size_t k, -- cgit v1.2.3