diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-04-13 22:09:16 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-04-13 22:27:11 +0200 |
commit | 0dff7f1ac43bb7d074db36ce2ce44c52e6760e7e (patch) | |
tree | e9c33e6067bd6dbc9fdb8306df5d14d45fab3ad4 /src/routines | |
parent | 0f49dd24e5307e52d748654aca303f15fa629b36 (diff) |
Made GEMM rotation expectations kernel-specific
Diffstat (limited to 'src/routines')
-rw-r--r-- | src/routines/level3/xgemm.cpp | 15 | ||||
-rw-r--r-- | src/routines/level3/xgemm.hpp | 39 | ||||
-rw-r--r-- | src/routines/levelx/xgemmbatched.cpp | 6 | ||||
-rw-r--r-- | src/routines/levelx/xgemmstridedbatched.cpp | 6 |
4 files changed, 36 insertions, 30 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 4c1b9558..6a314280 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -19,11 +19,6 @@ namespace clblast { // ================================================================================================= -// Defines the assumptions of the GEMM kernels -template <typename T> const bool Xgemm<T>::a_want_rotated_ = false; -template <typename T> const bool Xgemm<T>::b_want_rotated_ = true; -template <typename T> const bool Xgemm<T>::c_want_rotated_ = false; - // Constructor: forwards to base class constructor template <typename T> Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): @@ -69,7 +64,8 @@ void Xgemm<T>::DoGemm(const Layout layout, size_t a_one, a_two, b_one, b_two, c_one, c_two; ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, - a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + db_["GEMMK"]); // Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and // their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the @@ -122,13 +118,14 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, // Calculates the ceiled versions of m, n, and k const auto m_ceiled = Ceil(m, db_["MWG"]); const auto n_ceiled = Ceil(n, db_["NWG"]); - const auto k_ceiled = Ceil(k, db_["KWG"]); + const auto k_ceiled = Ceil(k, db_["KWG"] * db_["KREG"]); // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account // whether the matrices need to be rotated or not for the kernel. size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; - CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"], - a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"] * db_["KREG"], + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i, + db_["GEMMK"]); // Determines whether or not temporary matrices are needed auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate); diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index b51d1771..ec84fbb7 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -25,9 +25,9 @@ class Xgemm: public Routine { public: // Defines the assumptions of the GEMM kernels - static const bool a_want_rotated_; - static const bool b_want_rotated_; - static const bool c_want_rotated_; + static const bool a_want_rotated_(const size_t gemm_kernel_id) { return gemm_kernel_id == 1; } + static const bool b_want_rotated_(const size_t gemm_kernel_id) { return true; } + static const bool c_want_rotated_(const size_t gemm_kernel_id) { return gemm_kernel_id == 1; } // Computes the size of the temporary GEMM buffer based on user-arguments static size_t GetTempSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, @@ -35,20 +35,23 @@ class Xgemm: public Routine { const size_t a_offset, const size_t a_ld, const size_t b_offset, const size_t b_ld, const size_t c_offset, const size_t c_ld, - const size_t mwg, const size_t nwg, const size_t kwg) { + const size_t mwg, const size_t nwg, const size_t kwg, + const size_t gemm_kernel_id) { // Computes the transpose/conjugate options and sets the a/b/c sizes based on that bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; size_t a_one, a_two, b_one, b_two, c_one, c_two; ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, - a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + gemm_kernel_id); // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account // whether the matrices need to be rotated or not for the kernel. size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; CalculateInternalDimensions(m, n, k, mwg, nwg, kwg, - a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i, + gemm_kernel_id); // Determines whether or not temporary matrices are needed auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate); @@ -79,7 +82,8 @@ class Xgemm: public Routine { size_t& a_one, size_t& a_two, size_t& b_one, size_t& b_two, size_t& c_one, size_t& c_two, bool& a_do_transpose, bool& b_do_transpose, bool& c_do_transpose, - bool& a_conjugate, bool& b_conjugate) { + bool& a_conjugate, bool& b_conjugate, + const size_t gemm_kernel_id) { // Makes sure all dimensions are larger than zero if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); } @@ -94,9 +98,9 @@ class Xgemm: public Routine { const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) || (layout == Layout::kRowMajor && b_transpose == Transpose::kNo); const auto c_rotated = (layout == Layout::kRowMajor); - a_do_transpose = a_rotated != a_want_rotated_; - b_do_transpose = b_rotated != b_want_rotated_; - c_do_transpose = c_rotated != c_want_rotated_; + a_do_transpose = a_rotated != a_want_rotated_(gemm_kernel_id); + b_do_transpose = b_rotated != b_want_rotated_(gemm_kernel_id); + c_do_transpose = c_rotated != c_want_rotated_(gemm_kernel_id); // In case of complex data-types, the transpose can also become a conjugate transpose a_conjugate = (a_transpose == Transpose::kConjugate); @@ -136,16 +140,17 @@ class Xgemm: public Routine { static void CalculateInternalDimensions(const size_t m, const size_t n, const size_t k, const size_t mwg, const size_t nwg, const size_t kwg, size_t& a_one_i, size_t& a_two_i, size_t& b_one_i, - size_t& b_two_i, size_t& c_one_i, size_t& c_two_i) { + size_t& b_two_i, size_t& c_one_i, size_t& c_two_i, + const size_t gemm_kernel_id) { const auto m_ceiled = Ceil(m, mwg); const auto n_ceiled = Ceil(n, nwg); const auto k_ceiled = Ceil(k, kwg); - a_one_i = (a_want_rotated_) ? k_ceiled : m_ceiled; - a_two_i = (a_want_rotated_) ? m_ceiled : k_ceiled; - b_one_i = (b_want_rotated_) ? n_ceiled : k_ceiled; - b_two_i = (b_want_rotated_) ? k_ceiled : n_ceiled; - c_one_i = (c_want_rotated_) ? n_ceiled : m_ceiled; - c_two_i = (c_want_rotated_) ? m_ceiled : n_ceiled; + a_one_i = (a_want_rotated_(gemm_kernel_id)) ? k_ceiled : m_ceiled; + a_two_i = (a_want_rotated_(gemm_kernel_id)) ? m_ceiled : k_ceiled; + b_one_i = (b_want_rotated_(gemm_kernel_id)) ? n_ceiled : k_ceiled; + b_two_i = (b_want_rotated_(gemm_kernel_id)) ? k_ceiled : n_ceiled; + c_one_i = (c_want_rotated_(gemm_kernel_id)) ? n_ceiled : m_ceiled; + c_two_i = (c_want_rotated_(gemm_kernel_id)) ? m_ceiled : n_ceiled; } // Constructor diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp index f5ce83e7..2caf1e39 100644 --- a/src/routines/levelx/xgemmbatched.cpp +++ b/src/routines/levelx/xgemmbatched.cpp @@ -70,7 +70,8 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans size_t a_one, a_two, b_one, b_two, c_one, c_two; Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, - a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + db_["GEMMK"]); // Tests the matrices for validity for (auto batch = size_t{0}; batch < batch_count; ++batch) { @@ -141,7 +142,8 @@ void XgemmBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const // whether the matrices need to be rotated or not for the kernel. size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"], - a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i, + db_["GEMMK"]); // Sets the "internal" offsets, i.e. the perfect offsets auto a_offsets_i = std::vector<int>(batch_count); diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp index 48383cbd..8408f75a 100644 --- a/src/routines/levelx/xgemmstridedbatched.cpp +++ b/src/routines/levelx/xgemmstridedbatched.cpp @@ -66,7 +66,8 @@ void XgemmStridedBatched<T>::DoGemmStridedBatched(const Layout layout, const Tra size_t a_one, a_two, b_one, b_two, c_one, c_two; Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, - a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + db_["GEMMK"]); // Tests the matrices for validity for (auto batch = size_t{0}; batch < batch_count; ++batch) { @@ -122,7 +123,8 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, // whether the matrices need to be rotated or not for the kernel. size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"], - a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i, + db_["GEMMK"]); // Determines whether or not temporary matrices are needed auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && !a_do_transpose && !a_conjugate; |