summaryrefslogtreecommitdiff
path: root/src/routines
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-06 19:26:11 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-01-06 19:26:11 +0100
commitf1e3b35541245e9c9561592c24705bb23000498b (patch)
tree8c721372aa56b17d2849c27c3bbec2cb0ce8c872 /src/routines
parentc988c2cdd166ebf6d5b5ec20f445de1a95a65b16 (diff)
Reduced duplicate code in the batched GEMM implementation
Diffstat (limited to 'src/routines')
-rw-r--r--src/routines/levelx/xgemmbatched.cpp61
-rw-r--r--src/routines/levelx/xgemmbatched.hpp6
2 files changed, 20 insertions, 47 deletions
diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp
index 8ce2dedc..1c0953e8 100644
--- a/src/routines/levelx/xgemmbatched.cpp
+++ b/src/routines/levelx/xgemmbatched.cpp
@@ -12,6 +12,7 @@
// =================================================================================================
#include "routines/levelx/xgemmbatched.hpp"
+#include "routines/level3/xgemm.hpp"
#include <string>
#include <vector>
@@ -64,34 +65,12 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans
throw BLASError(StatusCode::kInvalidBatchCount);
}
- // Makes sure all dimensions are larger than zero
- if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
-
- // Computes whether or not the matrices are transposed in memory. See GEMM routine for details.
- const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
- (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
- const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
- (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
- const auto c_rotated = (layout == Layout::kRowMajor);
- static const auto a_want_rotated = false;
- static const auto b_want_rotated = true;
- static const auto c_want_rotated = false;
- const auto a_do_transpose = a_rotated != a_want_rotated;
- const auto b_do_transpose = b_rotated != b_want_rotated;
- const auto c_do_transpose = c_rotated != c_want_rotated;
-
- // In case of complex data-types, the transpose can also become a conjugate transpose
- const auto a_conjugate = (a_transpose == Transpose::kConjugate);
- const auto b_conjugate = (b_transpose == Transpose::kConjugate);
-
- // Computes the first and second dimensions of the 3 matrices taking into account whether the
- // matrices are rotated or not
- const auto a_one = (a_rotated) ? k : m;
- const auto a_two = (a_rotated) ? m : k;
- const auto b_one = (b_rotated) ? n : k;
- const auto b_two = (b_rotated) ? k : n;
- const auto c_one = (c_rotated) ? n : m;
- const auto c_two = (c_rotated) ? m : n;
+ // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+ bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
+ size_t a_one, a_two, b_one, b_two, c_one, c_two;
+ Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
// Tests the matrices for validity
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
@@ -130,10 +109,7 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans
a_buffer, a_offsets_int, a_ld, b_buffer, b_offsets_int, b_ld,
betas_device, c_buffer, c_offsets_int, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
- a_one, a_two, a_want_rotated,
- b_one, b_two, b_want_rotated,
- c_one, c_two, c_want_rotated,
- batch_count);
+ a_one, a_two, b_one, b_two, c_one, c_two, batch_count);
}
}
@@ -152,9 +128,9 @@ void XgemmBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const
const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
- const size_t a_one, const size_t a_two, const bool a_want_rotated,
- const size_t b_one, const size_t b_two, const bool b_want_rotated,
- const size_t c_one, const size_t c_two, const bool c_want_rotated,
+ const size_t a_one, const size_t a_two,
+ const size_t b_one, const size_t b_two,
+ const size_t c_one, const size_t c_two,
const size_t batch_count) {
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(Ceil(m, db_["MWG"]), db_["VWM"]);
@@ -163,12 +139,9 @@ void XgemmBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const
// Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
// whether the matrices need to be rotated or not for the kernel.
- const auto a_one_i = (a_want_rotated) ? k_ceiled : m_ceiled;
- const auto a_two_i = (a_want_rotated) ? m_ceiled : k_ceiled;
- const auto b_one_i = (b_want_rotated) ? n_ceiled : k_ceiled;
- const auto b_two_i = (b_want_rotated) ? k_ceiled : n_ceiled;
- const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled;
- const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled;
+ size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
+ Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
+ a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
// Sets the "internal" offsets, i.e. the perfect offsets
auto a_offsets_i = std::vector<int>(batch_count);
@@ -182,11 +155,11 @@ void XgemmBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const
// Determines whether or not temporary matrices are needed
auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offsets == a_offsets_i &&
- a_do_transpose == false && a_conjugate == false;
+ !a_do_transpose && !a_conjugate;
auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offsets == b_offsets_i &&
- b_do_transpose == false && b_conjugate == false;
+ !b_do_transpose && !b_conjugate;
auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offsets == c_offsets_i &&
- c_do_transpose == false;
+ !c_do_transpose;
// Creates the temporary matrices
const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, batch_count * a_one_i * a_two_i);
diff --git a/src/routines/levelx/xgemmbatched.hpp b/src/routines/levelx/xgemmbatched.hpp
index 6136dd5f..989f3815 100644
--- a/src/routines/levelx/xgemmbatched.hpp
+++ b/src/routines/levelx/xgemmbatched.hpp
@@ -48,9 +48,9 @@ class XgemmBatched: public Routine {
const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
- const size_t a_one, const size_t a_two, const bool a_want_rotated,
- const size_t b_one, const size_t b_two, const bool b_want_rotated,
- const size_t c_one, const size_t c_two, const bool c_want_rotated,
+ const size_t a_one, const size_t a_two,
+ const size_t b_one, const size_t b_two,
+ const size_t c_one, const size_t c_two,
const size_t batch_count);
// Direct version of batched GEMM (no pre and post-processing kernels)