summaryrefslogtreecommitdiff
path: root/src/routines/level3/xgemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines/level3/xgemm.hpp')
-rw-r--r--src/routines/level3/xgemm.hpp39
1 files changed, 39 insertions, 0 deletions
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index dab195c5..f0911d6a 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -39,6 +39,45 @@ class Xgemm: public Routine {
return (m_n_k < min_indirect_size_e3);
}
+ // Process the user-arguments, computes secondary parameters
+ static void ProcessArguments(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ size_t& a_one, size_t& a_two, size_t& b_one,
+ size_t& b_two, size_t& c_one, size_t& c_two,
+ bool& a_do_transpose, bool& b_do_transpose, bool& c_do_transpose,
+ bool& a_conjugate, bool& b_conjugate) {
+
+ // Makes sure all dimensions are larger than zero
+ if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+ // Computes whether or not the matrices are transposed in memory. This is based on their layout
+ // (row or column-major) and whether or not they are requested to be pre-transposed. Note
+ // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of
+ // col-major) to be transformed, so transposing requirements are not the same as whether or not
+ // the matrix is actually transposed in memory.
+ const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
+ (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
+ const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
+ (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
+ const auto c_rotated = (layout == Layout::kRowMajor);
+ a_do_transpose = a_rotated != a_want_rotated_;
+ b_do_transpose = b_rotated != b_want_rotated_;
+ c_do_transpose = c_rotated != c_want_rotated_;
+
+ // In case of complex data-types, the transpose can also become a conjugate transpose
+ a_conjugate = (a_transpose == Transpose::kConjugate);
+ b_conjugate = (b_transpose == Transpose::kConjugate);
+
+ // Computes the first and second dimensions of the 3 matrices taking into account whether the
+ // matrices are rotated or not
+ a_one = (a_rotated) ? k : m;
+ a_two = (a_rotated) ? m : k;
+ b_one = (b_rotated) ? n : k;
+ b_two = (b_rotated) ? k : n;
+ c_one = (c_rotated) ? n : m;
+ c_two = (c_rotated) ? m : n;
+ }
+
// Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
static size_t ComputeTempSize(const bool a_no_temp, const bool b_no_temp, const bool c_no_temp,
const size_t a_size, const size_t b_size, const size_t c_size,