diff options
Diffstat (limited to 'src/routines/level3/xgemm.cpp')
-rw-r--r-- | src/routines/level3/xgemm.cpp | 18 |
1 files changed, 6 insertions, 12 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 0015b629..7bd388c1 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -150,9 +150,6 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled; const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled; - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_); - // Determines whether or not temporary matrices are needed auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offset == 0 && a_do_transpose == false && a_conjugate == false; @@ -178,7 +175,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList, a_one, a_two, a_ld, a_offset, a_buffer, a_one_i, a_two_i, a_one_i, 0, a_temp, - ConstantOne<T>(), program, + ConstantOne<T>(), program_, true, a_do_transpose, a_conjugate); eventWaitList.push_back(eventProcessA); } @@ -189,7 +186,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList, b_one, b_two, b_ld, b_offset, b_buffer, b_one_i, b_two_i, b_one_i, 0, b_temp, - ConstantOne<T>(), program, + ConstantOne<T>(), program_, true, b_do_transpose, b_conjugate); eventWaitList.push_back(eventProcessB); } @@ -200,13 +197,13 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList, c_one, c_two, c_ld, c_offset, c_buffer, c_one_i, c_two_i, c_one_i, 0, c_temp, - ConstantOne<T>(), program, + ConstantOne<T>(), program_, true, c_do_transpose, false); eventWaitList.push_back(eventProcessC); } // Retrieves the Xgemm kernel from the compiled binary - auto kernel = Kernel(program, "Xgemm"); + auto kernel = Kernel(program_, "Xgemm"); // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(m_ceiled)); @@ -236,7 +233,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList, c_one_i, c_two_i, c_one_i, 0, c_temp, c_one, c_two, c_ld, c_offset, c_buffer, - ConstantOne<T>(), program, + ConstantOne<T>(), program_, false, c_do_transpose, false); } } @@ -255,13 +252,10 @@ void Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k, const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, const bool a_conjugate, const bool b_conjugate) { - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_); - // Retrieves the proper XgemmDirect kernel from the compiled binary const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectTT" : "XgemmDirectTN") : (b_do_transpose ? "XgemmDirectNT" : "XgemmDirectNN"); - auto kernel = Kernel(program, name); + auto kernel = Kernel(program_, name); // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(m)); |