From d8af24e3886b0012b21de8714277facf1b2c7159 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sun, 20 Nov 2016 16:27:54 +0100 Subject: Now correctly tests for validity of the B matrix in the TRMM routine --- src/routines/level3/xtrmm.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'src/routines/level3') diff --git a/src/routines/level3/xtrmm.cpp b/src/routines/level3/xtrmm.cpp index 1c1f5f90..ed810e72 100644 --- a/src/routines/level3/xtrmm.cpp +++ b/src/routines/level3/xtrmm.cpp @@ -46,6 +46,16 @@ void Xtrmm::DoTrmm(const Layout layout, const Side side, const Triangle trian // Checks for validity of the triangular A matrix TestMatrixA(k, k, a_buffer, a_offset, a_ld); + // Checks for validity of the input/output B matrix + const auto b_one = (layout == Layout::kRowMajor) ? n : m; + const auto b_two = (layout == Layout::kRowMajor) ? m : n; + TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld); + + // Creates a copy of B to avoid overwriting input in GEMM while computing output + const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset); + auto b_buffer_copy = Buffer(context_, b_size); + b_buffer.CopyTo(queue_, b_size, b_buffer_copy); + // Determines which kernel to run based on the layout (the Xgemm kernel assumes column-major as // default) and on whether we are dealing with an upper or lower triangle of the triangular matrix bool is_upper = ((triangle == Triangle::kUpper && layout != Layout::kRowMajor) || @@ -55,11 +65,6 @@ void Xtrmm::DoTrmm(const Layout layout, const Side side, const Triangle trian // Determines whether or not the triangular matrix is unit-diagonal auto unit_diagonal = (diagonal == Diagonal::kUnit) ? true : false; - // Creates a copy of B to avoid overwriting input in GEMM while computing output - const auto b_one = (layout == Layout::kRowMajor) ? 
m : n; - auto b_buffer_copy = Buffer(context_, b_one*b_ld + b_offset); - b_buffer.CopyTo(queue_, b_one*b_ld + b_offset, b_buffer_copy); - // Temporary buffer for a copy of the triangular matrix auto temp_triangular = Buffer(context_, k*k); -- cgit v1.2.3