summary | refs | log | tree | commit | diff
path: root/src/routines
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2017-12-30 18:45:06 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2017-12-30 18:45:06 +0100
commit ad1227c4f2934b0f60c0030101e18b8fb21daf8c (patch)
tree 00db3a1af9a52c93df4e9473e05fc8f636838e98 /src/routines
parent 6d1e30e61f5ef73f0a83e12f064cae64644034ca (diff)
Added optional temp-buffer argument to C++ interface of GEMM
Diffstat (limited to 'src/routines')
-rw-r--r-- src/routines/level3/xgemm.cpp 24
-rw-r--r-- src/routines/level3/xgemm.hpp 6
2 files changed, 21 insertions, 9 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index 1fe10462..4c1b9558 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -61,7 +61,8 @@ void Xgemm<T>::DoGemm(const Layout layout,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
- const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld) {
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { // optional arguments
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
@@ -94,7 +95,8 @@ void Xgemm<T>::DoGemm(const Layout layout,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
c_buffer, c_offset, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
- a_one, a_two, b_one, b_two, c_one, c_two);
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ temp_buffer, temp_buffer_provided);
}
}
@@ -114,7 +116,8 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
const bool a_conjugate, const bool b_conjugate,
const size_t a_one, const size_t a_two,
const size_t b_one, const size_t b_two,
- const size_t c_one, const size_t c_two) {
+ const size_t c_one, const size_t c_two,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided) {
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(m, db_["MWG"]);
@@ -143,12 +146,19 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
// Creates the buffer for the (optional) temporary matrices. Note that we use 'a_buffer' in case
// when no temporary buffer is needed, but that's just to make it compile: it is never used.
- const auto temp_buffer = (temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer;
+ const auto temp_buffer_all = (temp_buffer_provided) ? temp_buffer :
+ ((temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer);
+
+ // Verifies if the provided temporary buffer is large enough
+ if (temp_buffer_provided) {
+ const auto required_size = temp_size * sizeof(T);
+ if (temp_buffer_all.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryTemp); }
+ }
// Sets the buffer pointers for (temp) matrices A, B, and C
- const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer;
- const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer;
- const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer;
+ const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer_all;
+ const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer_all;
+ const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer_all;
// Events of all kernels (including pre/post processing kernels)
auto eventWaitList = std::vector<Event>();
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index 25b1f5c9..b354de1b 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -158,7 +158,8 @@ class Xgemm: public Routine {
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
- const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld);
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const Buffer<T> &temp_buffer = Buffer<T>(nullptr), const bool temp_buffer_provided = false);
// Indirect version of GEMM (with pre and post-processing kernels)
void GemmIndirect(const size_t m, const size_t n, const size_t k,
@@ -171,7 +172,8 @@ class Xgemm: public Routine {
const bool a_conjugate, const bool b_conjugate,
const size_t a_one, const size_t a_two,
const size_t b_one, const size_t b_two,
- const size_t c_one, const size_t c_two);
+ const size_t c_one, const size_t c_two,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided);
// Direct version of GEMM (no pre and post-processing kernels)
void GemmDirect(const size_t m, const size_t n, const size_t k,