From 2f0697564fea11bd3f91e4474d766de54ca5ac1b Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sun, 20 Nov 2016 15:05:42 +0100 Subject: Fixed a bug in the TRMM routine caused by overwriting input data before consuming everything --- src/routines/level3/xtrmm.cpp | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) (limited to 'src/routines/level3') diff --git a/src/routines/level3/xtrmm.cpp b/src/routines/level3/xtrmm.cpp index 6bf77cfa..1c1f5f90 100644 --- a/src/routines/level3/xtrmm.cpp +++ b/src/routines/level3/xtrmm.cpp @@ -30,11 +30,11 @@ Xtrmm::Xtrmm(Queue &queue, EventPointer event, const std::string &name): // The main routine template void Xtrmm::DoTrmm(const Layout layout, const Side side, const Triangle triangle, - const Transpose a_transpose, const Diagonal diagonal, - const size_t m, const size_t n, - const T alpha, - const Buffer &a_buffer, const size_t a_offset, const size_t a_ld, - const Buffer &b_buffer, const size_t b_offset, const size_t b_ld) { + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const Buffer &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer &b_buffer, const size_t b_offset, const size_t b_ld) { // Makes sure all dimensions are larger than zero if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); } @@ -55,6 +55,11 @@ void Xtrmm::DoTrmm(const Layout layout, const Side side, const Triangle trian // Determines whether or not the triangular matrix is unit-diagonal auto unit_diagonal = (diagonal == Diagonal::kUnit) ? true : false; + // Creates a copy of B to avoid overwriting input in GEMM while computing output + const auto b_one = (layout == Layout::kRowMajor) ? m : n; + auto b_buffer_copy = Buffer(context_, b_one*b_ld + b_offset); + b_buffer.CopyTo(queue_, b_one*b_ld + b_offset, b_buffer_copy); + // Temporary buffer for a copy of the triangular matrix auto temp_triangular = Buffer(context_, k*k); @@ -91,7 +96,7 @@ void Xtrmm::DoTrmm(const Layout layout, const Side side, const Triangle trian m, n, k, alpha, temp_triangular, 0, k, - b_buffer, b_offset, b_ld, + b_buffer_copy, b_offset, b_ld, static_cast(0.0), b_buffer, b_offset, b_ld); } @@ -102,7 +107,7 @@ void Xtrmm::DoTrmm(const Layout layout, const Side side, const Triangle trian DoGemm(layout, Transpose::kNo, a_transpose, m, n, k, alpha, - b_buffer, b_offset, b_ld, + b_buffer_copy, b_offset, b_ld, temp_triangular, 0, k, static_cast(0.0), b_buffer, b_offset, b_ld); -- cgit v1.2.3