summaryrefslogtreecommitdiff
path: root/src/routines/level3/xgemm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines/level3/xgemm.cpp')
-rw-r--r--src/routines/level3/xgemm.cpp14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index ee33c8be..143ef3c1 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -275,9 +275,11 @@ StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
// Loads the program from the database
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
- // Retrieves the XgemmDirect kernel from the compiled binary
+ // Retrieves the proper XgemmDirect kernel from the compiled binary
try {
- auto kernel = Kernel(program, "XgemmDirect");
+ const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectTT" : "XgemmDirectTN") :
+ (b_do_transpose ? "XgemmDirectNT" : "XgemmDirectNN");
+ auto kernel = Kernel(program, name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(m));
@@ -294,11 +296,9 @@ StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
kernel.SetArgument(11, c_buffer());
kernel.SetArgument(12, static_cast<int>(c_offset));
kernel.SetArgument(13, static_cast<int>(c_ld));
- kernel.SetArgument(14, static_cast<int>(a_do_transpose));
- kernel.SetArgument(15, static_cast<int>(b_do_transpose));
- kernel.SetArgument(16, static_cast<int>(c_do_transpose));
- kernel.SetArgument(17, static_cast<int>(a_conjugate));
- kernel.SetArgument(18, static_cast<int>(b_conjugate));
+ kernel.SetArgument(14, static_cast<int>(c_do_transpose));
+ kernel.SetArgument(15, static_cast<int>(a_conjugate));
+ kernel.SetArgument(16, static_cast<int>(b_conjugate));
// Computes the global and local thread sizes
const auto m_ceiled = Ceil(m, db_["WGD"]);