From 8400ee3a097952a49371973780b47fcbf63e9a5f Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Mon, 15 May 2017 22:04:55 +0200 Subject: Fixed an TRSM issue caused by incorrect block size calculation --- src/routines/level3/xtrsm.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp index c1c07d10..685d458b 100644 --- a/src/routines/level3/xtrsm.cpp +++ b/src/routines/level3/xtrsm.cpp @@ -145,9 +145,10 @@ void Xtrsm::TrsmColMajor(const Side side, const Triangle triangle, // True when (upper triangular) or (lower triangular and transposed) else { - const auto current_block_size = (m % block_size == 0) ? block_size : (m % block_size); - const auto i_start = static_cast(m) - static_cast(current_block_size); + const auto special_block_size = (m % block_size == 0) ? block_size : (m % block_size); + const auto i_start = static_cast(m) - static_cast(special_block_size); for (auto i = i_start; i >= 0; i -= static_cast(block_size)) { + const auto current_block_size = (i == i_start) ? special_block_size : block_size; const auto gemm_alpha = (i == i_start) ? alpha : ConstantOne(); DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo, current_block_size, n, current_block_size, gemm_alpha, @@ -157,7 +158,7 @@ void Xtrsm::TrsmColMajor(const Side side, const Triangle triangle, if (i - static_cast(block_size) < 0) { break; } const auto this_a_offset = (a_transpose == Transpose::kNo) ? i * a_ld : i; DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo, - i, n, block_size, ConstantNegOne(), + i, n, current_block_size, ConstantNegOne(), a_buffer, this_a_offset, a_ld, x_buffer, x_offset + i, x_ld, gemm_alpha, b_buffer, b_offset, b_ld); @@ -170,9 +171,10 @@ void Xtrsm::TrsmColMajor(const Side side, const Triangle triangle, // True when (lower triangular) or (upper triangular and transposed) if (condition) { - const auto current_block_size = (n % block_size == 0) ? block_size : (n % block_size); - const auto i_start = static_cast(n) - static_cast(current_block_size); + const auto special_block_size = (n % block_size == 0) ? block_size : (n % block_size); + const auto i_start = static_cast(n) - static_cast(special_block_size); for (auto i = i_start; i >= 0; i -= static_cast(block_size)) { + const auto current_block_size = (i == i_start) ? special_block_size : block_size; const auto gemm_alpha = (i == i_start) ? alpha : ConstantOne(); DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose, m, current_block_size, current_block_size, gemm_alpha, @@ -182,7 +184,7 @@ void Xtrsm::TrsmColMajor(const Side side, const Triangle triangle, if (i - static_cast(block_size) < 0) { break; } const auto this_a_offset = (a_transpose == Transpose::kNo) ? i : i * a_ld; DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose, - m, i, block_size, ConstantNegOne(), + m, i, current_block_size, ConstantNegOne(), x_buffer, x_offset + i * x_ld, x_ld, a_buffer, this_a_offset, a_ld, gemm_alpha, b_buffer, b_offset, b_ld); -- cgit v1.2.3