diff options
Diffstat (limited to 'src/routines/level3/xgemm.hpp')
-rw-r--r-- | src/routines/level3/xgemm.hpp | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index dab195c5..f0911d6a 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -39,6 +39,45 @@ class Xgemm: public Routine { return (m_n_k < min_indirect_size_e3); } + // Processes the user-arguments; writes the derived dimensions and transpose/conjugate flags + static void ProcessArguments(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + size_t& a_one, size_t& a_two, size_t& b_one, + size_t& b_two, size_t& c_one, size_t& c_two, + bool& a_do_transpose, bool& b_do_transpose, bool& c_do_transpose, + bool& a_conjugate, bool& b_conjugate) { + + // Makes sure all dimensions are larger than zero, throws kInvalidDimension otherwise + if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); } + + // Computes whether or not the matrices are transposed in memory. This is based on their layout + // (row or column-major) and whether or not they are requested to be pre-transposed. Note + // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of + // col-major) to be transformed, so transposing requirements are not the same as whether or not + // the matrix is actually transposed in memory. 
+ const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) || + (layout == Layout::kRowMajor && a_transpose == Transpose::kNo); + const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) || + (layout == Layout::kRowMajor && b_transpose == Transpose::kNo); + const auto c_rotated = (layout == Layout::kRowMajor); + a_do_transpose = a_rotated != a_want_rotated_; + b_do_transpose = b_rotated != b_want_rotated_; + c_do_transpose = c_rotated != c_want_rotated_; + + // In case of complex data-types, the transpose can also become a conjugate transpose + a_conjugate = (a_transpose == Transpose::kConjugate); + b_conjugate = (b_transpose == Transpose::kConjugate); + + // Computes the first and second dimensions of the 3 matrices taking into account whether the + // matrices are rotated or not + a_one = (a_rotated) ? k : m; + a_two = (a_rotated) ? m : k; + b_one = (b_rotated) ? n : k; + b_two = (b_rotated) ? k : n; + c_one = (c_rotated) ? n : m; + c_two = (c_rotated) ? m : n; + } + // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices static size_t ComputeTempSize(const bool a_no_temp, const bool b_no_temp, const bool c_no_temp, const size_t a_size, const size_t b_size, const size_t c_size, |