summary | refs | log | tree | commit | diff
path: root/src/routines
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2016-11-20 16:27:54 +0100
committer: Cedric Nugteren <web@cedricnugteren.nl> 2016-11-20 16:27:54 +0100
commit: d8af24e3886b0012b21de8714277facf1b2c7159 (patch)
tree: a0460887c3adb39a6180fae7bbe55e3c1fce3b27 /src/routines
parent: 90eb8738c47a95ce0d141582bfba9d83ed8d6e50 (diff)
Now correctly tests for validity of the B matrix in the TRMM routine
Diffstat (limited to 'src/routines')
-rw-r--r-- src/routines/level3/xtrmm.cpp | 15
1 files changed, 10 insertions, 5 deletions
diff --git a/src/routines/level3/xtrmm.cpp b/src/routines/level3/xtrmm.cpp
index 1c1f5f90..ed810e72 100644
--- a/src/routines/level3/xtrmm.cpp
+++ b/src/routines/level3/xtrmm.cpp
@@ -46,6 +46,16 @@ void Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle trian
// Checks for validity of the triangular A matrix
TestMatrixA(k, k, a_buffer, a_offset, a_ld);
+ // Checks for validity of the input/output B matrix
+ const auto b_one = (layout == Layout::kRowMajor) ? n : m;
+ const auto b_two = (layout == Layout::kRowMajor) ? m : n;
+ TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld);
+
+ // Creates a copy of B to avoid overwriting input in GEMM while computing output
+ const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset);
+ auto b_buffer_copy = Buffer<T>(context_, b_size);
+ b_buffer.CopyTo(queue_, b_size, b_buffer_copy);
+
// Determines which kernel to run based on the layout (the Xgemm kernel assumes column-major as
// default) and on whether we are dealing with an upper or lower triangle of the triangular matrix
bool is_upper = ((triangle == Triangle::kUpper && layout != Layout::kRowMajor) ||
@@ -55,11 +65,6 @@ void Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle trian
// Determines whether or not the triangular matrix is unit-diagonal
auto unit_diagonal = (diagonal == Diagonal::kUnit) ? true : false;
- // Creates a copy of B to avoid overwriting input in GEMM while computing output
- const auto b_one = (layout == Layout::kRowMajor) ? m : n;
- auto b_buffer_copy = Buffer<T>(context_, b_one*b_ld + b_offset);
- b_buffer.CopyTo(queue_, b_one*b_ld + b_offset, b_buffer_copy);
-
// Temporary buffer for a copy of the triangular matrix
auto temp_triangular = Buffer<T>(context_, k*k);