diff options
Diffstat (limited to 'src/routines/levelx')
-rw-r--r-- | src/routines/levelx/xgemmbatched.cpp | 6 | ||||
-rw-r--r-- | src/routines/levelx/xgemmstridedbatched.cpp | 6 |
2 files changed, 8 insertions, 4 deletions
diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp index f5ce83e7..2caf1e39 100644 --- a/src/routines/levelx/xgemmbatched.cpp +++ b/src/routines/levelx/xgemmbatched.cpp @@ -70,7 +70,8 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans size_t a_one, a_two, b_one, b_two, c_one, c_two; Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, - a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + db_["GEMMK"]); // Tests the matrices for validity for (auto batch = size_t{0}; batch < batch_count; ++batch) { @@ -141,7 +142,8 @@ void XgemmBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const // whether the matrices need to be rotated or not for the kernel. size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"], - a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i, + db_["GEMMK"]); // Sets the "internal" offsets, i.e. the perfect offsets auto a_offsets_i = std::vector<int>(batch_count); diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp index 48383cbd..8408f75a 100644 --- a/src/routines/levelx/xgemmstridedbatched.cpp +++ b/src/routines/levelx/xgemmstridedbatched.cpp @@ -66,7 +66,8 @@ void XgemmStridedBatched<T>::DoGemmStridedBatched(const Layout layout, const Tra size_t a_one, a_two, b_one, b_two, c_one, c_two; Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, - a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + db_["GEMMK"]); // Tests the matrices for validity for (auto batch = size_t{0}; batch < batch_count; ++batch) { @@ -122,7 +123,8 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, // whether the matrices need to be rotated or not for the kernel. size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"], - a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i, + db_["GEMMK"]); // Determines whether or not temporary matrices are needed auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && !a_do_transpose && !a_conjugate; |