diff options
-rw-r--r-- | src/routines/level3/xtrsm.cpp | 23 | ||||
-rw-r--r-- | src/routines/levelx/xinvert.cpp | 6 |
2 files changed, 13 insertions, 16 deletions
diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp
index 3a910261..8fe33b64 100644
--- a/src/routines/level3/xtrsm.cpp
+++ b/src/routines/level3/xtrsm.cpp
@@ -54,11 +54,6 @@ void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle trian
   // Checks for validity of the triangular A matrix
   TestMatrixA(k, k, a_buffer, a_offset, a_ld);
 
-  // Determines which kernels to run based on the layout (the kernels assume column-major as
-  // default) and on whether we are dealing with an upper or lower triangle of the triangular matrix
-  const bool is_upper = ((triangle == Triangle::kUpper && layout != Layout::kRowMajor) ||
-                         (triangle == Triangle::kLower && layout == Layout::kRowMajor));
-
   // Checks for validity of the input B matrix
   const auto b_one = (layout == Layout::kRowMajor) ? n : m;
   const auto b_two = (layout == Layout::kRowMajor) ? m : n;
@@ -91,15 +86,17 @@ void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle trian
                                       k, block_size, a_buffer, a_offset, a_ld, a_inv_buffer);
   diagonal_invert_event.WaitForCompletion();
 
-  // Lower of upper triangular
-  const bool condition = ((triangle == Triangle::kUpper && a_transpose != Transpose::kNo) ||
-                          (triangle == Triangle::kLower && a_transpose == Transpose::kNo));
+  // Derives properties based on the arguments
+  const auto is_upper = ((triangle == Triangle::kUpper && a_transpose == Transpose::kNo) ||
+                         (triangle == Triangle::kLower && a_transpose != Transpose::kNo));
+  const auto is_transposed = ((layout == Layout::kColMajor && a_transpose == Transpose::kNo) ||
+                              (layout != Layout::kColMajor && a_transpose != Transpose::kNo));
 
   // Left side
   if (side == Side::kLeft) {
 
     // True when (lower triangular) or (upper triangular and transposed)
-    if (condition) {
+    if ((!is_upper && !is_transposed) || (is_upper && is_transposed)) {
       for (auto i = size_t{0}; i < m; i += block_size) {
         const auto gemm_alpha = (i == 0) ? alpha : ConstantOne<T>();
         const auto current_block_size = std::min(m - i, block_size);
@@ -125,7 +122,7 @@ void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle trian
       for (auto i = i_start; i >= 0; i -= static_cast<int>(block_size)) {
         const auto gemm_alpha = (i == i_start) ? alpha : ConstantOne<T>();
         DoGemm(layout, a_transpose, Transpose::kNo,
-               block_size, n, block_size, gemm_alpha,
+               current_block_size, n, current_block_size, gemm_alpha,
                a_inv_buffer, i * block_size, block_size,
                b_buffer, i, b_ld, ConstantZero<T>(),
                x_buffer, i, x_ld);
@@ -144,20 +141,20 @@ void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle trian
   else {
 
     // True when (lower triangular) or (upper triangular and transposed)
-    if (condition) {
+    if ((!is_upper && !is_transposed) || (is_upper && is_transposed)) {
       const auto current_block_size = (n % block_size == 0) ? block_size : (n % block_size);
       const auto i_start = static_cast<int>(n) - static_cast<int>(current_block_size);
       for (auto i = i_start; i >= 0; i -= static_cast<int>(block_size)) {
         const auto gemm_alpha = (i == i_start) ? alpha : ConstantOne<T>();
         DoGemm(layout, Transpose::kNo, a_transpose,
-               m, block_size, block_size, gemm_alpha,
+               m, current_block_size, current_block_size, gemm_alpha,
                b_buffer, i * b_ld, b_ld,
                a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
                x_buffer, i * x_ld, x_ld);
         if (i - static_cast<int>(block_size) < 0) { break; }
         const auto this_a_offset = (a_transpose == Transpose::kNo) ? i : i * a_ld;
         DoGemm(layout, Transpose::kNo, a_transpose,
-               m, i, block_size, ConstantNegOne<T>(),
+               m, i, current_block_size, ConstantNegOne<T>(),
                x_buffer, i * x_ld, x_ld,
                a_buffer, this_a_offset, a_ld, ConstantOne<T>(),
                b_buffer, 0, b_ld);
diff --git a/src/routines/levelx/xinvert.cpp b/src/routines/levelx/xinvert.cpp
index 696e694a..bcc3706d 100644
--- a/src/routines/levelx/xinvert.cpp
+++ b/src/routines/levelx/xinvert.cpp
@@ -41,8 +41,8 @@ void Xinvert<T>::InvertMatrixDiagonalBlocks(const Layout layout, const Triangle
                                             const Buffer<T> &src, const size_t offset, const size_t ld_src,
                                             Buffer<T> &dest) {
 
-  // Makes sure all dimensions are larger than zero and the block size is smaller than n
-  if ((block_size == 0) || (n == 0) || (block_size > n)) {
+  // Makes sure all dimensions are larger than zero
+  if ((block_size == 0) || (n == 0)) {
     throw BLASError(StatusCode::kInvalidDimension);
   }
 
@@ -56,7 +56,7 @@ void Xinvert<T>::InvertMatrixDiagonalBlocks(const Layout layout, const Triangle
   // This routine only supports block sizes which are a multiple of the internal block size and
   // block sizes up to and including 128
   if ((block_size % internal_block_size != 0) || (block_size > 128)) {
-    throw BLASError(StatusCode::kInvalidDimension);
+    throw BLASError(StatusCode::kUnknownError);
   }
 
   // Checks for validity of the source and destination matrices