summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-12-28 13:56:18 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-12-28 13:56:18 +0100
commitaaea9474a154a9f07534523e4ca66e4b2c5f2d4f (patch)
tree5293c909ac634eff502f1c8704d725af84576d05 /src
parent74792ce96c828fdb4962fa5fc3192178b2a9386b (diff)
Factored out argument processing from the GEMM routine
Diffstat (limited to 'src')
-rw-r--r--src/routines/level3/xgemm.cpp35
-rw-r--r--src/routines/level3/xgemm.hpp39
2 files changed, 45 insertions, 29 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index 4b8c465f..1fe10462 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -63,35 +63,12 @@ void Xgemm<T>::DoGemm(const Layout layout,
const T beta,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld) {
- // Makes sure all dimensions are larger than zero
- if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
-
- // Computes whether or not the matrices are transposed in memory. This is based on their layout
- // (row or column-major) and whether or not they are requested to be pre-transposed. Note
- // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of
- // col-major) to be transformed, so transposing requirements are not the same as whether or not
- // the matrix is actually transposed in memory.
- const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
- (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
- const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
- (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
- const auto c_rotated = (layout == Layout::kRowMajor);
- const auto a_do_transpose = a_rotated != a_want_rotated_;
- const auto b_do_transpose = b_rotated != b_want_rotated_;
- const auto c_do_transpose = c_rotated != c_want_rotated_;
-
- // In case of complex data-types, the transpose can also become a conjugate transpose
- const auto a_conjugate = (a_transpose == Transpose::kConjugate);
- const auto b_conjugate = (b_transpose == Transpose::kConjugate);
-
- // Computes the first and second dimensions of the 3 matrices taking into account whether the
- // matrices are rotated or not
- const auto a_one = (a_rotated) ? k : m;
- const auto a_two = (a_rotated) ? m : k;
- const auto b_one = (b_rotated) ? n : k;
- const auto b_two = (b_rotated) ? k : n;
- const auto c_one = (c_rotated) ? n : m;
- const auto c_two = (c_rotated) ? m : n;
+ // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+ bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
+ size_t a_one, a_two, b_one, b_two, c_one, c_two;
+ ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
// Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and
// their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index dab195c5..f0911d6a 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -39,6 +39,45 @@ class Xgemm: public Routine {
return (m_n_k < min_indirect_size_e3);
}
+ // Process the user-arguments, computes secondary parameters
+ static void ProcessArguments(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ size_t& a_one, size_t& a_two, size_t& b_one,
+ size_t& b_two, size_t& c_one, size_t& c_two,
+ bool& a_do_transpose, bool& b_do_transpose, bool& c_do_transpose,
+ bool& a_conjugate, bool& b_conjugate) {
+
+ // Makes sure all dimensions are larger than zero
+ if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+ // Computes whether or not the matrices are transposed in memory. This is based on their layout
+ // (row or column-major) and whether or not they are requested to be pre-transposed. Note
+ // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of
+ // col-major) to be transformed, so transposing requirements are not the same as whether or not
+ // the matrix is actually transposed in memory.
+ const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
+ (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
+ const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
+ (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
+ const auto c_rotated = (layout == Layout::kRowMajor);
+ a_do_transpose = a_rotated != a_want_rotated_;
+ b_do_transpose = b_rotated != b_want_rotated_;
+ c_do_transpose = c_rotated != c_want_rotated_;
+
+ // In case of complex data-types, the transpose can also become a conjugate transpose
+ a_conjugate = (a_transpose == Transpose::kConjugate);
+ b_conjugate = (b_transpose == Transpose::kConjugate);
+
+ // Computes the first and second dimensions of the 3 matrices taking into account whether the
+ // matrices are rotated or not
+ a_one = (a_rotated) ? k : m;
+ a_two = (a_rotated) ? m : k;
+ b_one = (b_rotated) ? n : k;
+ b_two = (b_rotated) ? k : n;
+ c_one = (c_rotated) ? n : m;
+ c_two = (c_rotated) ? m : n;
+ }
+
// Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
static size_t ComputeTempSize(const bool a_no_temp, const bool b_no_temp, const bool c_no_temp,
const size_t a_size, const size_t b_size, const size_t c_size,