diff options
Diffstat (limited to 'src/routines/level2/xtrsv.cpp')
-rw-r--r-- | src/routines/level2/xtrsv.cpp | 10 |
1 file changed, 2 insertions, 8 deletions
diff --git a/src/routines/level2/xtrsv.cpp b/src/routines/level2/xtrsv.cpp index b0e4c5ae..d5d009ff 100644 --- a/src/routines/level2/xtrsv.cpp +++ b/src/routines/level2/xtrsv.cpp @@ -37,9 +37,6 @@ void Xtrsv<T>::Substitution(const Layout layout, const Triangle triangle, if (n > db_["TRSV_BLOCK_SIZE"]) { throw BLASError(StatusCode::kUnexpectedError); }; - // Retrieves the program from the cache - const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), "TRSV"); - // Translates CLBlast arguments to 0/1 integers for the OpenCL kernel const auto is_unit_diagonal = (diagonal == Diagonal::kNonUnit) ? 0 : 1; const auto is_transposed = ((a_transpose == Transpose::kNo && layout == Layout::kColMajor) || @@ -52,7 +49,7 @@ void Xtrsv<T>::Substitution(const Layout layout, const Triangle triangle, // Retrieves the kernel from the compiled binary const auto kernel_name = (is_upper) ? "trsv_backward" : "trsv_forward"; - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(n)); @@ -94,9 +91,6 @@ void Xtrsv<T>::DoTrsv(const Layout layout, const Triangle triangle, TestMatrixA(n, n, a_buffer, a_offset, a_ld); TestVectorX(n, b_buffer, b_offset, b_inc); - // Retrieves the program from the cache - const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), "TRSV"); - // Creates a copy of B to avoid overwriting input while computing output // TODO: Make x with 0 offset and unit increment by creating custom copy-to and copy-from kernels const auto x_offset = b_offset; @@ -108,7 +102,7 @@ void Xtrsv<T>::DoTrsv(const Layout layout, const Triangle triangle, // Fills the output buffer with zeros auto eventWaitList = std::vector<Event>(); auto fill_vector_event = Event(); - FillVector(queue_, device_, program, db_, fill_vector_event.pointer(), eventWaitList, + FillVector(queue_, device_, program_, db_, fill_vector_event.pointer(), eventWaitList, n, x_inc, 
x_offset, x_buffer, ConstantZero<T>()); fill_vector_event.WaitForCompletion(); |