Diffstat (limited to 'src/routines/level3/xgemm.cpp')
-rw-r--r--  src/routines/level3/xgemm.cpp  107
1 file changed, 42 insertions(+), 65 deletions(-)
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index edba1f00..4c1b9558 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -19,6 +19,11 @@
namespace clblast {
// =================================================================================================
+// Defines the assumptions of the GEMM kernels
+template <typename T> const bool Xgemm<T>::a_want_rotated_ = false;
+template <typename T> const bool Xgemm<T>::b_want_rotated_ = true;
+template <typename T> const bool Xgemm<T>::c_want_rotated_ = false;
+
// Constructor: forwards to base class constructor
template <typename T>
Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
@@ -56,40 +61,15 @@ void Xgemm<T>::DoGemm(const Layout layout,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
- const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld) {
-
- // Makes sure all dimensions are larger than zero
- if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
-
- // Computes whether or not the matrices are transposed in memory. This is based on their layout
- // (row or column-major) and whether or not they are requested to be pre-transposed. Note
- // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of
- // col-major) to be transformed, so transposing requirements are not the same as whether or not
- // the matrix is actually transposed in memory.
- const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
- (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
- const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
- (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
- const auto c_rotated = (layout == Layout::kRowMajor);
- static const auto a_want_rotated = false;
- static const auto b_want_rotated = true;
- static const auto c_want_rotated = false;
- const auto a_do_transpose = a_rotated != a_want_rotated;
- const auto b_do_transpose = b_rotated != b_want_rotated;
- const auto c_do_transpose = c_rotated != c_want_rotated;
-
- // In case of complex data-types, the transpose can also become a conjugate transpose
- const auto a_conjugate = (a_transpose == Transpose::kConjugate);
- const auto b_conjugate = (b_transpose == Transpose::kConjugate);
-
- // Computes the first and second dimensions of the 3 matrices taking into account whether the
- // matrices are rotated or not
- const auto a_one = (a_rotated) ? k : m;
- const auto a_two = (a_rotated) ? m : k;
- const auto b_one = (b_rotated) ? n : k;
- const auto b_two = (b_rotated) ? k : n;
- const auto c_one = (c_rotated) ? n : m;
- const auto c_two = (c_rotated) ? m : n;
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { // optional arguments
+
+ // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+ bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
+ size_t a_one, a_two, b_one, b_two, c_one, c_two;
+ ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
// Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and
// their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the
@@ -103,11 +83,7 @@ void Xgemm<T>::DoGemm(const Layout layout,
TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld);
// Selects which version of GEMM to run
- const auto m_n_k = static_cast<unsigned long long>(m) * static_cast<unsigned long long>(n) *
- static_cast<unsigned long long>(k);
- const auto database_value = static_cast<unsigned long long>(db_["XGEMM_MIN_INDIRECT_SIZE"]);
- const auto min_indirect_size = database_value * database_value * database_value;
- const auto do_gemm_direct = (m_n_k < min_indirect_size);
+ const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
if (do_gemm_direct) { // for small sizes (single kernel)
GemmDirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
@@ -119,9 +95,8 @@ void Xgemm<T>::DoGemm(const Layout layout,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
c_buffer, c_offset, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
- a_one, a_two, a_want_rotated,
- b_one, b_two, b_want_rotated,
- c_one, c_two, c_want_rotated);
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ temp_buffer, temp_buffer_provided);
}
}
@@ -139,9 +114,11 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
- const size_t a_one, const size_t a_two, const bool a_want_rotated,
- const size_t b_one, const size_t b_two, const bool b_want_rotated,
- const size_t c_one, const size_t c_two, const bool c_want_rotated) {
+ const size_t a_one, const size_t a_two,
+ const size_t b_one, const size_t b_two,
+ const size_t c_one, const size_t c_two,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided) {
+
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(m, db_["MWG"]);
const auto n_ceiled = Ceil(n, db_["NWG"]);
@@ -149,39 +126,39 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
// Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
// whether the matrices need to be rotated or not for the kernel.
- const auto a_one_i = (a_want_rotated) ? k_ceiled : m_ceiled;
- const auto a_two_i = (a_want_rotated) ? m_ceiled : k_ceiled;
- const auto b_one_i = (b_want_rotated) ? n_ceiled : k_ceiled;
- const auto b_two_i = (b_want_rotated) ? k_ceiled : n_ceiled;
- const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled;
- const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled;
+ size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
+ CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
+ a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
// Determines whether or not temporary matrices are needed
- auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offset == 0 &&
- a_do_transpose == false && a_conjugate == false;
- auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offset == 0 &&
- b_do_transpose == false && b_conjugate == false;
- auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offset == 0 &&
- c_do_transpose == false;
+ auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate);
+ auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate);
+ auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false);
// Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
- auto temp_size = size_t{0};
auto b_temp_offset = size_t{0};
auto c_temp_offset = size_t{0};
- if (!a_no_temp) { temp_size += a_one_i*a_two_i; }
- if (!b_no_temp) { b_temp_offset = temp_size; temp_size += b_one_i*b_two_i; }
- if (!c_no_temp) { c_temp_offset = temp_size; temp_size += c_one_i*c_two_i; }
+ const auto temp_size = ComputeTempSize(a_no_temp, b_no_temp, c_no_temp,
+ a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i,
+ b_temp_offset, c_temp_offset);
if (!IsMultiple(b_temp_offset, db_["VWN"])) { throw BLASError(StatusCode::kUnexpectedError); }
if (!IsMultiple(c_temp_offset, db_["VWM"])) { throw BLASError(StatusCode::kUnexpectedError); }
// Creates the buffer for the (optional) temporary matrices. Note that we use 'a_buffer' in case
// when no temporary buffer is needed, but that's just to make it compile: it is never used.
- const auto temp_buffer = (temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer;
+ const auto temp_buffer_all = (temp_buffer_provided) ? temp_buffer :
+ ((temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer);
+
+ // Verifies if the provided temporary buffer is large enough
+ if (temp_buffer_provided) {
+ const auto required_size = temp_size * sizeof(T);
+ if (temp_buffer_all.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryTemp); }
+ }
// Sets the buffer pointers for (temp) matrices A, B, and C
- const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer;
- const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer;
- const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer;
+ const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer_all;
+ const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer_all;
+ const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer_all;
// Events of all kernels (including pre/post processing kernels)
auto eventWaitList = std::vector<Event>();
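
For reference, below is a minimal sketch of two of the helpers that this change factors the removed inline logic into. It is a reconstruction from the call sites and the deleted lines above, not part of the commit itself; the exact names and signatures in xgemm.hpp may differ.

// Sketch only: reconstructed from the inline code removed from DoGemm above.
// Computes the transpose/conjugate options and the a/b/c sizes for the kernel.
template <typename T>
void Xgemm<T>::ProcessArguments(const Layout layout,
                                const Transpose a_transpose, const Transpose b_transpose,
                                const size_t m, const size_t n, const size_t k,
                                size_t &a_one, size_t &a_two, size_t &b_one, size_t &b_two,
                                size_t &c_one, size_t &c_two,
                                bool &a_do_transpose, bool &b_do_transpose, bool &c_do_transpose,
                                bool &a_conjugate, bool &b_conjugate) {

  // Makes sure all dimensions are larger than zero (check moved out of DoGemm)
  if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }

  // Whether the matrices are transposed in memory, based on layout and requested transpose
  const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
                         (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
  const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
                         (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
  const auto c_rotated = (layout == Layout::kRowMajor);

  // Compares against the kernel's expectations, now expressed as the static class members
  a_do_transpose = a_rotated != a_want_rotated_;
  b_do_transpose = b_rotated != b_want_rotated_;
  c_do_transpose = c_rotated != c_want_rotated_;

  // For complex data-types, the transpose can also become a conjugate transpose
  a_conjugate = (a_transpose == Transpose::kConjugate);
  b_conjugate = (b_transpose == Transpose::kConjugate);

  // First and second dimensions of the three matrices, taking rotation into account
  a_one = (a_rotated) ? k : m;
  a_two = (a_rotated) ? m : k;
  b_one = (b_rotated) ? n : k;
  b_two = (b_rotated) ? k : n;
  c_one = (c_rotated) ? n : m;
  c_two = (c_rotated) ? m : n;
}

// Sketch only: selects the direct (single-kernel) GEMM when m*n*k is below the tuned
// threshold, matching the comparison removed from DoGemm above.
template <typename T>
bool Xgemm<T>::UseDirectKernel(const size_t m, const size_t n, const size_t k,
                               const size_t min_indirect_size) {
  const auto m_n_k = static_cast<unsigned long long>(m) * static_cast<unsigned long long>(n) *
                     static_cast<unsigned long long>(k);
  const auto threshold = static_cast<unsigned long long>(min_indirect_size);
  return m_n_k < threshold * threshold * threshold;
}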