diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-10-28 17:32:37 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-10-28 17:32:37 +0200 |
commit | 12b08ae49154379f7471a40809ace6418857b387 (patch) | |
tree | ef958197db0bb8a67c9a5840f828b3f6c72bd8fc /src/routines/level3/xtrsm.cpp | |
parent | 2949e156f5bfdd724987e67477da3e3608e4aaf9 (diff) | |
parent | fa6e5e67f585b77d34c3031c176de9a0f7904aa9 (diff) |
Merge branch 'master' into android_support
Diffstat (limited to 'src/routines/level3/xtrsm.cpp')
-rw-r--r-- | src/routines/level3/xtrsm.cpp | 110 |
1 files changed, 69 insertions, 41 deletions
diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp
index 685d458b..d622e3bf 100644
--- a/src/routines/level3/xtrsm.cpp
+++ b/src/routines/level3/xtrsm.cpp
@@ -73,7 +73,7 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
                             const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld) {
 
   // Settings
-  constexpr auto block_size = size_t{32}; // tuneable
+  constexpr auto block_size = size_t{16}; // tuneable
 
   // Makes sure all dimensions are larger than zero
   if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
@@ -128,18 +128,25 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
       for (auto i = size_t{0}; i < m; i += block_size) {
         const auto gemm_alpha = (i == 0) ? alpha : ConstantOne<T>();
         const auto current_block_size = std::min(m - i, block_size);
-        DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
-               current_block_size, n, current_block_size, gemm_alpha,
-               a_inv_buffer, i * block_size, block_size,
-               b_buffer, b_offset + i, b_ld, ConstantZero<T>(),
-               x_buffer, x_offset + i, x_ld);
+        auto gemm1_event = Event();
+        auto gemm1 = Xgemm<T>(queue_, gemm1_event.pointer());
+        gemm1.DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
+                     current_block_size, n, current_block_size, gemm_alpha,
+                     a_inv_buffer, i * block_size, block_size,
+                     b_buffer, b_offset + i, b_ld, ConstantZero<T>(),
+                     x_buffer, x_offset + i, x_ld);
+        gemm1_event.WaitForCompletion();
         if (i + block_size >= m) { break; }
+
         const auto this_a_offset = (a_transpose == Transpose::kNo) ? (i + block_size) + i * a_ld : i + (block_size + i) * a_ld;
-        DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
-               m - i - block_size, n, block_size, ConstantNegOne<T>(),
-               a_buffer, this_a_offset, a_ld,
-               x_buffer, x_offset + i, x_ld, gemm_alpha,
-               b_buffer, b_offset + i + block_size, b_ld);
+        auto gemm2_event = Event();
+        auto gemm2 = Xgemm<T>(queue_, gemm2_event.pointer());
+        gemm2.DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
+                     m - i - block_size, n, block_size, ConstantNegOne<T>(),
+                     a_buffer, this_a_offset + a_offset, a_ld,
+                     x_buffer, x_offset + i, x_ld, gemm_alpha,
+                     b_buffer, b_offset + i + block_size, b_ld);
+        gemm2_event.WaitForCompletion();
       }
     }
 
@@ -150,18 +157,25 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
       for (auto i = i_start; i >= 0; i -= static_cast<int>(block_size)) {
         const auto current_block_size = (i == i_start) ? special_block_size : block_size;
         const auto gemm_alpha = (i == i_start) ? alpha : ConstantOne<T>();
-        DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
-               current_block_size, n, current_block_size, gemm_alpha,
-               a_inv_buffer, i * block_size, block_size,
-               b_buffer, b_offset + i, b_ld, ConstantZero<T>(),
-               x_buffer, x_offset + i, x_ld);
+        auto gemm1_event = Event();
+        auto gemm1 = Xgemm<T>(queue_, gemm1_event.pointer());
+        gemm1.DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
+                     current_block_size, n, current_block_size, gemm_alpha,
+                     a_inv_buffer, i * block_size, block_size,
+                     b_buffer, b_offset + i, b_ld, ConstantZero<T>(),
+                     x_buffer, x_offset + i, x_ld);
+        gemm1_event.WaitForCompletion();
         if (i - static_cast<int>(block_size) < 0) { break; }
+
         const auto this_a_offset = (a_transpose == Transpose::kNo) ? i * a_ld : i;
-        DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
-               i, n, current_block_size, ConstantNegOne<T>(),
-               a_buffer, this_a_offset, a_ld,
-               x_buffer, x_offset + i, x_ld, gemm_alpha,
-               b_buffer, b_offset, b_ld);
+        auto gemm2_event = Event();
+        auto gemm2 = Xgemm<T>(queue_, gemm2_event.pointer());
+        gemm2.DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
+                     i, n, current_block_size, ConstantNegOne<T>(),
+                     a_buffer, this_a_offset + a_offset, a_ld,
+                     x_buffer, x_offset + i, x_ld, gemm_alpha,
+                     b_buffer, b_offset, b_ld);
+        gemm2_event.WaitForCompletion();
       }
     }
   }
@@ -176,18 +190,25 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
       for (auto i = i_start; i >= 0; i -= static_cast<int>(block_size)) {
         const auto current_block_size = (i == i_start) ? special_block_size : block_size;
         const auto gemm_alpha = (i == i_start) ? alpha : ConstantOne<T>();
-        DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
-               m, current_block_size, current_block_size, gemm_alpha,
-               b_buffer, b_offset + i * b_ld, b_ld,
-               a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
-               x_buffer, x_offset + i * x_ld, x_ld);
+        auto gemm1_event = Event();
+        auto gemm1 = Xgemm<T>(queue_, gemm1_event.pointer());
+        gemm1.DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
+                     m, current_block_size, current_block_size, gemm_alpha,
+                     b_buffer, b_offset + i * b_ld, b_ld,
+                     a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
+                     x_buffer, x_offset + i * x_ld, x_ld);
+        gemm1_event.WaitForCompletion();
         if (i - static_cast<int>(block_size) < 0) { break; }
+
         const auto this_a_offset = (a_transpose == Transpose::kNo) ? i : i * a_ld;
-        DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
-               m, i, current_block_size, ConstantNegOne<T>(),
-               x_buffer, x_offset + i * x_ld, x_ld,
-               a_buffer, this_a_offset, a_ld, gemm_alpha,
-               b_buffer, b_offset, b_ld);
+        auto gemm2_event = Event();
+        auto gemm2 = Xgemm<T>(queue_, gemm2_event.pointer());
+        gemm2.DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
+                     m, i, current_block_size, ConstantNegOne<T>(),
+                     x_buffer, x_offset + i * x_ld, x_ld,
+                     a_buffer, this_a_offset + a_offset, a_ld, gemm_alpha,
+                     b_buffer, b_offset, b_ld);
+        gemm2_event.WaitForCompletion();
       }
     }
 
@@ -196,18 +217,25 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
       for (auto i = size_t{0}; i < n; i += block_size) {
         const auto gemm_alpha = (i == 0) ? alpha : ConstantOne<T>();
         const auto current_block_size = std::min(n - i, block_size);
-        DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
-               m, current_block_size, current_block_size, gemm_alpha,
-               b_buffer, b_offset + i * b_ld, b_ld,
-               a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
-               x_buffer, x_offset + i * x_ld, x_ld);
+        auto gemm1_event = Event();
+        auto gemm1 = Xgemm<T>(queue_, gemm1_event.pointer());
+        gemm1.DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
+                     m, current_block_size, current_block_size, gemm_alpha,
+                     b_buffer, b_offset + i * b_ld, b_ld,
+                     a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
+                     x_buffer, x_offset + i * x_ld, x_ld);
+        gemm1_event.WaitForCompletion();
         if (i + block_size >= n) { break; }
+
         const auto this_a_offset = (a_transpose == Transpose::kNo) ? i + (block_size + i) * a_ld : (i + block_size) + i * a_ld;
-        DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
-               m, n - i - block_size, block_size, ConstantNegOne<T>(),
-               x_buffer, x_offset + i * x_ld, x_ld,
-               a_buffer, this_a_offset, a_ld, gemm_alpha,
-               b_buffer, b_offset + (i + block_size) * b_ld, b_ld);
+        auto gemm2_event = Event();
+        auto gemm2 = Xgemm<T>(queue_, gemm2_event.pointer());
+        gemm2.DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
+                     m, n - i - block_size, block_size, ConstantNegOne<T>(),
+                     x_buffer, x_offset + i * x_ld, x_ld,
+                     a_buffer, this_a_offset + a_offset, a_ld, gemm_alpha,
+                     b_buffer, b_offset + (i + block_size) * b_ld, b_ld);
+        gemm2_event.WaitForCompletion();
       }
     }
   }