diff options
Diffstat (limited to 'src/routines/level3')
-rw-r--r-- | src/routines/level3/xgemm.cpp | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index cb24460a..6daa0fcf 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -216,9 +216,11 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, kernel.SetArgument(9, static_cast<int>(c_temp_offset / db_["VWM"])); // Computes the global and local thread sizes + const auto global_divider_one = c_want_rotated_(db_["GEMMK"]) ? db_["NWG"] : db_["MWG"]; + const auto global_divider_two = c_want_rotated_(db_["GEMMK"]) ? db_["MWG"] : db_["NWG"]; const auto global = std::vector<size_t>{ - (c_one_i * db_["MDIMC"]) / db_["MWG"], - (c_two_i * db_["NDIMC"]) / db_["NWG"] + (c_one_i * db_["MDIMC"]) / global_divider_one, + (c_two_i * db_["NDIMC"]) / global_divider_two }; const auto local = std::vector<size_t>{db_["MDIMC"], db_["NDIMC"]}; |