summaryrefslogtreecommitdiff
path: root/src/routines/level3/xgemm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines/level3/xgemm.cpp')
-rw-r--r--src/routines/level3/xgemm.cpp6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index cb24460a..6daa0fcf 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -216,9 +216,11 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
kernel.SetArgument(9, static_cast<int>(c_temp_offset / db_["VWM"]));
// Computes the global and local thread sizes
+ const auto global_divider_one = c_want_rotated_(db_["GEMMK"]) ? db_["NWG"] : db_["MWG"];
+ const auto global_divider_two = c_want_rotated_(db_["GEMMK"]) ? db_["MWG"] : db_["NWG"];
const auto global = std::vector<size_t>{
- (c_one_i * db_["MDIMC"]) / db_["MWG"],
- (c_two_i * db_["NDIMC"]) / db_["NWG"]
+ (c_one_i * db_["MDIMC"]) / global_divider_one,
+ (c_two_i * db_["NDIMC"]) / global_divider_two
};
const auto local = std::vector<size_t>{db_["MDIMC"], db_["NDIMC"]};