diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-10-02 17:59:05 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-10-02 17:59:05 +0200 |
commit | d8827e908cd7ff70e1bf294468c12e76c749317e (patch) | |
tree | 9122743c2e1b4c2d122d76805b3fd4163c500d7c /src/routines/level3 | |
parent | 61f489e370c56075e166caff6d1ad671ca6787b9 (diff) |
Specialised the GEMM direct kernel in four ways for transposing/non-transposing: NN, NT, TN, TT
Diffstat (limited to 'src/routines/level3')
-rw-r--r-- | src/routines/level3/xgemm.cpp | 14 |
1 file changed, 7 insertions, 7 deletions
```diff
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index ee33c8be..143ef3c1 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -275,9 +275,11 @@ StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
   // Loads the program from the database
   const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
 
-  // Retrieves the XgemmDirect kernel from the compiled binary
+  // Retrieves the proper XgemmDirect kernel from the compiled binary
   try {
-    auto kernel = Kernel(program, "XgemmDirect");
+    const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectTT" : "XgemmDirectTN") :
+                                         (b_do_transpose ? "XgemmDirectNT" : "XgemmDirectNN");
+    auto kernel = Kernel(program, name);
 
     // Sets the kernel arguments
     kernel.SetArgument(0, static_cast<int>(m));
@@ -294,11 +296,9 @@ StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
     kernel.SetArgument(11, c_buffer());
     kernel.SetArgument(12, static_cast<int>(c_offset));
     kernel.SetArgument(13, static_cast<int>(c_ld));
-    kernel.SetArgument(14, static_cast<int>(a_do_transpose));
-    kernel.SetArgument(15, static_cast<int>(b_do_transpose));
-    kernel.SetArgument(16, static_cast<int>(c_do_transpose));
-    kernel.SetArgument(17, static_cast<int>(a_conjugate));
-    kernel.SetArgument(18, static_cast<int>(b_conjugate));
+    kernel.SetArgument(14, static_cast<int>(c_do_transpose));
+    kernel.SetArgument(15, static_cast<int>(a_conjugate));
+    kernel.SetArgument(16, static_cast<int>(b_conjugate));
 
     // Computes the global and local thread sizes
     const auto m_ceiled = Ceil(m, db_["WGD"]);
```