From aaea9474a154a9f07534523e4ca66e4b2c5f2d4f Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Thu, 28 Dec 2017 13:56:18 +0100 Subject: Factored out argument processing from the GEMM routine --- src/routines/level3/xgemm.cpp | 35 ++++++----------------------------- src/routines/level3/xgemm.hpp | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 4b8c465f..1fe10462 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -63,35 +63,12 @@ void Xgemm::DoGemm(const Layout layout, const T beta, const Buffer &c_buffer, const size_t c_offset, const size_t c_ld) { - // Makes sure all dimensions are larger than zero - if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); } - - // Computes whether or not the matrices are transposed in memory. This is based on their layout - // (row or column-major) and whether or not they are requested to be pre-transposed. Note - // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of - // col-major) to be transformed, so transposing requirements are not the same as whether or not - // the matrix is actually transposed in memory. - const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) || - (layout == Layout::kRowMajor && a_transpose == Transpose::kNo); - const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) || - (layout == Layout::kRowMajor && b_transpose == Transpose::kNo); - const auto c_rotated = (layout == Layout::kRowMajor); - const auto a_do_transpose = a_rotated != a_want_rotated_; - const auto b_do_transpose = b_rotated != b_want_rotated_; - const auto c_do_transpose = c_rotated != c_want_rotated_; - - // In case of complex data-types, the transpose can also become a conjugate transpose - const auto a_conjugate = (a_transpose == Transpose::kConjugate); - const auto b_conjugate = (b_transpose == Transpose::kConjugate); - - // Computes the first and second dimensions of the 3 matrices taking into account whether the - // matrices are rotated or not - const auto a_one = (a_rotated) ? k : m; - const auto a_two = (a_rotated) ? m : k; - const auto b_one = (b_rotated) ? n : k; - const auto b_two = (b_rotated) ? k : n; - const auto c_one = (c_rotated) ? n : m; - const auto c_two = (c_rotated) ? m : n; + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that + bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; + size_t a_one, a_two, b_one, b_two, c_one, c_two; + ProcessArguments(layout, a_transpose, b_transpose, m, n, k, + a_one, a_two, b_one, b_two, c_one, c_two, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); // Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and // their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index dab195c5..f0911d6a 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -39,6 +39,45 @@ class Xgemm: public Routine { return (m_n_k < min_indirect_size_e3); } + // Process the user-arguments, computes secondary parameters + static void ProcessArguments(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + size_t& a_one, size_t& a_two, size_t& b_one, + size_t& b_two, size_t& c_one, size_t& c_two, + bool& a_do_transpose, bool& b_do_transpose, bool& c_do_transpose, + bool& a_conjugate, bool& b_conjugate) { + + // Makes sure all dimensions are larger than zero + if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); } + + // Computes whether or not the matrices are transposed in memory. This is based on their layout + // (row or column-major) and whether or not they are requested to be pre-transposed. Note + // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of + // col-major) to be transformed, so transposing requirements are not the same as whether or not + // the matrix is actually transposed in memory. + const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) || + (layout == Layout::kRowMajor && a_transpose == Transpose::kNo); + const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) || + (layout == Layout::kRowMajor && b_transpose == Transpose::kNo); + const auto c_rotated = (layout == Layout::kRowMajor); + a_do_transpose = a_rotated != a_want_rotated_; + b_do_transpose = b_rotated != b_want_rotated_; + c_do_transpose = c_rotated != c_want_rotated_; + + // In case of complex data-types, the transpose can also become a conjugate transpose + a_conjugate = (a_transpose == Transpose::kConjugate); + b_conjugate = (b_transpose == Transpose::kConjugate); + + // Computes the first and second dimensions of the 3 matrices taking into account whether the + // matrices are rotated or not + a_one = (a_rotated) ? k : m; + a_two = (a_rotated) ? m : k; + b_one = (b_rotated) ? n : k; + b_two = (b_rotated) ? k : n; + c_one = (c_rotated) ? n : m; + c_two = (c_rotated) ? m : n; + } + // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices static size_t ComputeTempSize(const bool a_no_temp, const bool b_no_temp, const bool c_no_temp, const size_t a_size, const size_t b_size, const size_t c_size, -- cgit v1.2.3