diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/routines/level3/xtrmm.cpp | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/src/routines/level3/xtrmm.cpp b/src/routines/level3/xtrmm.cpp index 1c1f5f90..ed810e72 100644 --- a/src/routines/level3/xtrmm.cpp +++ b/src/routines/level3/xtrmm.cpp @@ -46,6 +46,16 @@ void Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle trian // Checks for validity of the triangular A matrix TestMatrixA(k, k, a_buffer, a_offset, a_ld); + // Checks for validity of the input/output B matrix + const auto b_one = (layout == Layout::kRowMajor) ? n : m; + const auto b_two = (layout == Layout::kRowMajor) ? m : n; + TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld); + + // Creates a copy of B to avoid overwriting input in GEMM while computing output + const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset); + auto b_buffer_copy = Buffer<T>(context_, b_size); + b_buffer.CopyTo(queue_, b_size, b_buffer_copy); + // Determines which kernel to run based on the layout (the Xgemm kernel assumes column-major as // default) and on whether we are dealing with an upper or lower triangle of the triangular matrix bool is_upper = ((triangle == Triangle::kUpper && layout != Layout::kRowMajor) || @@ -55,11 +65,6 @@ void Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle trian // Determines whether or not the triangular matrix is unit-diagonal auto unit_diagonal = (diagonal == Diagonal::kUnit) ? true : false; - // Creates a copy of B to avoid overwriting input in GEMM while computing output - const auto b_one = (layout == Layout::kRowMajor) ? m : n; - auto b_buffer_copy = Buffer<T>(context_, b_one*b_ld + b_offset); - b_buffer.CopyTo(queue_, b_one*b_ld + b_offset, b_buffer_copy); - // Temporary buffer for a copy of the triangular matrix auto temp_triangular = Buffer<T>(context_, k*k); |