diff options
-rw-r--r-- | src/routines/level3/xgemm.cpp | 10 | ||||
-rw-r--r-- | src/routines/level3/xgemm.hpp | 6 |
2 files changed, 9 insertions, 7 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index e2e8647e..a85f55b5 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -111,7 +111,9 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
                         a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
                         c_buffer, c_offset, c_ld,
                         a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
-                        a_one, a_two, b_one, b_two, c_one, c_two);
+                        a_one, a_two, a_want_rotated,
+                        b_one, b_two, b_want_rotated,
+                        c_one, c_two, c_want_rotated);
   }
 }
 
@@ -129,9 +131,9 @@ StatusCode Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k
                                   const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
                                   const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
                                   const bool a_conjugate, const bool b_conjugate,
-                                  const size_t a_one, const size_t a_two,
-                                  const size_t b_one, const size_t b_two,
-                                  const size_t c_one, const size_t c_two) {
+                                  const size_t a_one, const size_t a_two, const bool a_want_rotated,
+                                  const size_t b_one, const size_t b_two, const bool b_want_rotated,
+                                  const size_t c_one, const size_t c_two, const bool c_want_rotated) {
   auto status = StatusCode::kSuccess;
 
   // Calculates the ceiled versions of m, n, and k
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index 8db1cb11..46e12453 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -45,9 +45,9 @@ class Xgemm: public Routine {
                           const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
                           const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
                           const bool a_conjugate, const bool b_conjugate,
-                          const size_t a_one, const size_t a_two,
-                          const size_t b_one, const size_t b_two,
-                          const size_t c_one, const size_t c_two);
+                          const size_t a_one, const size_t a_two, const bool a_want_rotated,
+                          const size_t b_one, const size_t b_two, const bool b_want_rotated,
+                          const size_t c_one, const size_t c_two, const bool c_want_rotated);
 
   // Direct version of GEMM (no pre and post-processing kernels)
   StatusCode GemmDirect(const size_t m, const size_t n, const size_t k,