summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-12-30 18:45:06 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-12-30 18:45:06 +0100
commitad1227c4f2934b0f60c0030101e18b8fb21daf8c (patch)
tree00db3a1af9a52c93df4e9473e05fc8f636838e98
parent6d1e30e61f5ef73f0a83e12f064cae64644034ca (diff)
Added optional temp-buffer argument to C++ interface of GEMM
-rw-r--r--include/clblast.h4
-rw-r--r--src/clblast.cpp18
-rw-r--r--src/clpp11.hpp9
-rw-r--r--src/cupp11.hpp6
-rw-r--r--src/routines/level3/xgemm.cpp24
-rw-r--r--src/routines/level3/xgemm.hpp6
6 files changed, 43 insertions, 24 deletions
diff --git a/include/clblast.h b/include/clblast.h
index 3318768a..a05b487f 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -97,6 +97,7 @@ enum class StatusCode {
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
+ kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small
kInvalidBatchCount = -2049, // The batch count needs to be positive
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
@@ -520,7 +521,8 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
- cl_command_queue* queue, cl_event* event = nullptr);
+ cl_command_queue* queue, cl_event* event = nullptr,
+ cl_mem temp_buffer = nullptr);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
diff --git a/src/clblast.cpp b/src/clblast.cpp
index e38a75ca..06449840 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -1651,17 +1651,21 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
- cl_command_queue* queue, cl_event* event) {
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer) { // optional argument
try {
auto queue_cpp = Queue(*queue);
auto routine = Xgemm<T>(queue_cpp, event);
+ auto temp_buffer_provided = temp_buffer != nullptr;
+ auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);
routine.DoGemm(layout, a_transpose, b_transpose,
m, n, k,
alpha,
Buffer<T>(a_buffer), a_offset, a_ld,
Buffer<T>(b_buffer), b_offset, b_ld,
beta,
- Buffer<T>(c_buffer), c_offset, c_ld);
+ Buffer<T>(c_buffer), c_offset, c_ld,
+ temp_buffer_cpp, temp_buffer_provided);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
@@ -1672,7 +1676,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const float,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double,
@@ -1680,7 +1684,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const double,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2,
@@ -1688,7 +1692,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const float2,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2,
@@ -1696,7 +1700,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
const cl_mem, const size_t, const size_t,
const double2,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half,
@@ -1704,7 +1708,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
const cl_mem, const size_t, const size_t,
const half,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
diff --git a/src/clpp11.hpp b/src/clpp11.hpp
index 6ebf1322..2119f26b 100644
--- a/src/clpp11.hpp
+++ b/src/clpp11.hpp
@@ -614,10 +614,11 @@ class Buffer {
}
// Regular constructor with memory management. If this class does not own the buffer object, then
- // the memory will not be freed automatically afterwards.
+ // the memory will not be freed automatically afterwards. If the size is set to 0, this will
+ // become a stub containing a nullptr
explicit Buffer(const Context &context, const BufferAccess access, const size_t size):
- buffer_(new cl_mem, [access](cl_mem* m) {
- if (access != BufferAccess::kNotOwned) { CheckError(clReleaseMemObject(*m)); }
+ buffer_(new cl_mem, [access, size](cl_mem* m) {
+ if (access != BufferAccess::kNotOwned && size > 0) { CheckError(clReleaseMemObject(*m)); }
delete m;
}),
access_(access) {
@@ -625,7 +626,7 @@ class Buffer {
if (access_ == BufferAccess::kReadOnly) { flags = CL_MEM_READ_ONLY; }
if (access_ == BufferAccess::kWriteOnly) { flags = CL_MEM_WRITE_ONLY; }
auto status = CL_SUCCESS;
- *buffer_ = clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status);
+ *buffer_ = (size > 0) ? clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status) : nullptr;
CLCudaAPIError::Check(status, "clCreateBuffer");
}
diff --git a/src/cupp11.hpp b/src/cupp11.hpp
index eb177ca2..509ae3e8 100644
--- a/src/cupp11.hpp
+++ b/src/cupp11.hpp
@@ -549,12 +549,12 @@ public:
// Regular constructor with memory management. If this class does not own the buffer object, then
// the memory will not be freed automatically afterwards.
explicit Buffer(const Context &, const BufferAccess access, const size_t size):
- buffer_(new CUdeviceptr, [access](CUdeviceptr* m) {
- if (access != BufferAccess::kNotOwned) { CheckError(cuMemFree(*m)); }
+ buffer_(new CUdeviceptr, [access, size](CUdeviceptr* m) {
+ if (access != BufferAccess::kNotOwned && size > 0) { CheckError(cuMemFree(*m)); }
delete m;
}),
access_(access) {
- CheckError(cuMemAlloc(buffer_.get(), size*sizeof(T)));
+ if (size > 0) { CheckError(cuMemAlloc(buffer_.get(), size*sizeof(T))); }
}
// As above, but now with read/write access as a default
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index 1fe10462..4c1b9558 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -61,7 +61,8 @@ void Xgemm<T>::DoGemm(const Layout layout,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
- const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld) {
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { // optional arguments
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
@@ -94,7 +95,8 @@ void Xgemm<T>::DoGemm(const Layout layout,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
c_buffer, c_offset, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
- a_one, a_two, b_one, b_two, c_one, c_two);
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ temp_buffer, temp_buffer_provided);
}
}
@@ -114,7 +116,8 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
const bool a_conjugate, const bool b_conjugate,
const size_t a_one, const size_t a_two,
const size_t b_one, const size_t b_two,
- const size_t c_one, const size_t c_two) {
+ const size_t c_one, const size_t c_two,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided) {
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(m, db_["MWG"]);
@@ -143,12 +146,19 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
// Creates the buffer for the (optional) temporary matrices. Note that we use 'a_buffer' in case
// when no temporary buffer is needed, but that's just to make it compile: it is never used.
- const auto temp_buffer = (temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer;
+ const auto temp_buffer_all = (temp_buffer_provided) ? temp_buffer :
+ ((temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer);
+
+ // Verifies if the provided temporary buffer is large enough
+ if (temp_buffer_provided) {
+ const auto required_size = temp_size * sizeof(T);
+ if (temp_buffer_all.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryTemp); }
+ }
// Sets the buffer pointers for (temp) matrices A, B, and C
- const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer;
- const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer;
- const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer;
+ const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer_all;
+ const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer_all;
+ const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer_all;
// Events of all kernels (including pre/post processing kernels)
auto eventWaitList = std::vector<Event>();
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index 25b1f5c9..b354de1b 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -158,7 +158,8 @@ class Xgemm: public Routine {
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
- const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld);
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const Buffer<T> &temp_buffer = Buffer<T>(nullptr), const bool temp_buffer_provided = false);
// Indirect version of GEMM (with pre and post-processing kernels)
void GemmIndirect(const size_t m, const size_t n, const size_t k,
@@ -171,7 +172,8 @@ class Xgemm: public Routine {
const bool a_conjugate, const bool b_conjugate,
const size_t a_one, const size_t a_two,
const size_t b_one, const size_t b_two,
- const size_t c_one, const size_t c_two);
+ const size_t c_one, const size_t c_two,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided);
// Direct version of GEMM (no pre and post-processing kernels)
void GemmDirect(const size_t m, const size_t n, const size_t k,