summaryrefslogtreecommitdiff
path: root/src/routines/level3
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-10-25 20:34:38 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-10-25 20:35:39 +0200
commitd49aae236eca2dddc3228791ec367cc0c9d3bc7e (patch)
treefabe67d50312f73d77b1e29c9dc8ac0c41452d26 /src/routines/level3
parent42ac3b474844387891cf0a5556fd45b8883d647a (diff)
Fixed a bug in TRSM routine due to missing event synchronisations after GEMM calls
Diffstat (limited to 'src/routines/level3')
-rw-r--r--src/routines/level3/xtrsm.cpp108
1 files changed, 68 insertions, 40 deletions
diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp
index 685d458b..61669e44 100644
--- a/src/routines/level3/xtrsm.cpp
+++ b/src/routines/level3/xtrsm.cpp
@@ -128,18 +128,25 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
for (auto i = size_t{0}; i < m; i += block_size) {
const auto gemm_alpha = (i == 0) ? alpha : ConstantOne<T>();
const auto current_block_size = std::min(m - i, block_size);
- DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
- current_block_size, n, current_block_size, gemm_alpha,
- a_inv_buffer, i * block_size, block_size,
- b_buffer, b_offset + i, b_ld, ConstantZero<T>(),
- x_buffer, x_offset + i, x_ld);
+ auto gemm1_event = Event();
+ auto gemm1 = Xgemm<T>(queue_, gemm1_event.pointer());
+ gemm1.DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
+ current_block_size, n, current_block_size, gemm_alpha,
+ a_inv_buffer, i * block_size, block_size,
+ b_buffer, b_offset + i, b_ld, ConstantZero<T>(),
+ x_buffer, x_offset + i, x_ld);
+ gemm1_event.WaitForCompletion();
if (i + block_size >= m) { break; }
+
const auto this_a_offset = (a_transpose == Transpose::kNo) ? (i + block_size) + i * a_ld : i + (block_size + i) * a_ld;
- DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
- m - i - block_size, n, block_size, ConstantNegOne<T>(),
- a_buffer, this_a_offset, a_ld,
- x_buffer, x_offset + i, x_ld, gemm_alpha,
- b_buffer, b_offset + i + block_size, b_ld);
+ auto gemm2_event = Event();
+ auto gemm2 = Xgemm<T>(queue_, gemm2_event.pointer());
+ gemm2.DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
+ m - i - block_size, n, block_size, ConstantNegOne<T>(),
+ a_buffer, this_a_offset, a_ld,
+ x_buffer, x_offset + i, x_ld, gemm_alpha,
+ b_buffer, b_offset + i + block_size, b_ld);
+ gemm2_event.WaitForCompletion();
}
}
@@ -150,18 +157,25 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
for (auto i = i_start; i >= 0; i -= static_cast<int>(block_size)) {
const auto current_block_size = (i == i_start) ? special_block_size : block_size;
const auto gemm_alpha = (i == i_start) ? alpha : ConstantOne<T>();
- DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
- current_block_size, n, current_block_size, gemm_alpha,
- a_inv_buffer, i * block_size, block_size,
- b_buffer, b_offset + i, b_ld, ConstantZero<T>(),
- x_buffer, x_offset + i, x_ld);
+ auto gemm1_event = Event();
+ auto gemm1 = Xgemm<T>(queue_, gemm1_event.pointer());
+ gemm1.DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
+ current_block_size, n, current_block_size, gemm_alpha,
+ a_inv_buffer, i * block_size, block_size,
+ b_buffer, b_offset + i, b_ld, ConstantZero<T>(),
+ x_buffer, x_offset + i, x_ld);
+ gemm1_event.WaitForCompletion();
if (i - static_cast<int>(block_size) < 0) { break; }
+
const auto this_a_offset = (a_transpose == Transpose::kNo) ? i * a_ld : i;
- DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
- i, n, current_block_size, ConstantNegOne<T>(),
- a_buffer, this_a_offset, a_ld,
- x_buffer, x_offset + i, x_ld, gemm_alpha,
- b_buffer, b_offset, b_ld);
+ auto gemm2_event = Event();
+ auto gemm2 = Xgemm<T>(queue_, gemm2_event.pointer());
+ gemm2.DoGemm(Layout::kColMajor, a_transpose, Transpose::kNo,
+ i, n, current_block_size, ConstantNegOne<T>(),
+ a_buffer, this_a_offset, a_ld,
+ x_buffer, x_offset + i, x_ld, gemm_alpha,
+ b_buffer, b_offset, b_ld);
+ gemm2_event.WaitForCompletion();
}
}
}
@@ -176,18 +190,25 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
for (auto i = i_start; i >= 0; i -= static_cast<int>(block_size)) {
const auto current_block_size = (i == i_start) ? special_block_size : block_size;
const auto gemm_alpha = (i == i_start) ? alpha : ConstantOne<T>();
- DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
- m, current_block_size, current_block_size, gemm_alpha,
- b_buffer, b_offset + i * b_ld, b_ld,
- a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
- x_buffer, x_offset + i * x_ld, x_ld);
+ auto gemm1_event = Event();
+ auto gemm1 = Xgemm<T>(queue_, gemm1_event.pointer());
+ gemm1.DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
+ m, current_block_size, current_block_size, gemm_alpha,
+ b_buffer, b_offset + i * b_ld, b_ld,
+ a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
+ x_buffer, x_offset + i * x_ld, x_ld);
+ gemm1_event.WaitForCompletion();
if (i - static_cast<int>(block_size) < 0) { break; }
+
const auto this_a_offset = (a_transpose == Transpose::kNo) ? i : i * a_ld;
- DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
- m, i, current_block_size, ConstantNegOne<T>(),
- x_buffer, x_offset + i * x_ld, x_ld,
- a_buffer, this_a_offset, a_ld, gemm_alpha,
- b_buffer, b_offset, b_ld);
+ auto gemm2_event = Event();
+ auto gemm2 = Xgemm<T>(queue_, gemm2_event.pointer());
+ gemm2.DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
+ m, i, current_block_size, ConstantNegOne<T>(),
+ x_buffer, x_offset + i * x_ld, x_ld,
+ a_buffer, this_a_offset, a_ld, gemm_alpha,
+ b_buffer, b_offset, b_ld);
+ gemm2_event.WaitForCompletion();
}
}
@@ -196,18 +217,25 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
for (auto i = size_t{0}; i < n; i += block_size) {
const auto gemm_alpha = (i == 0) ? alpha : ConstantOne<T>();
const auto current_block_size = std::min(n - i, block_size);
- DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
- m, current_block_size, current_block_size, gemm_alpha,
- b_buffer, b_offset + i * b_ld, b_ld,
- a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
- x_buffer, x_offset + i * x_ld, x_ld);
+ auto gemm1_event = Event();
+ auto gemm1 = Xgemm<T>(queue_, gemm1_event.pointer());
+ gemm1.DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
+ m, current_block_size, current_block_size, gemm_alpha,
+ b_buffer, b_offset + i * b_ld, b_ld,
+ a_inv_buffer, i * block_size, block_size, ConstantZero<T>(),
+ x_buffer, x_offset + i * x_ld, x_ld);
+ gemm1_event.WaitForCompletion();
if (i + block_size >= n) { break; }
+
const auto this_a_offset = (a_transpose == Transpose::kNo) ? i + (block_size + i) * a_ld : (i + block_size) + i * a_ld;
- DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
- m, n - i - block_size, block_size, ConstantNegOne<T>(),
- x_buffer, x_offset + i * x_ld, x_ld,
- a_buffer, this_a_offset, a_ld, gemm_alpha,
- b_buffer, b_offset + (i + block_size) * b_ld, b_ld);
+ auto gemm2_event = Event();
+ auto gemm2 = Xgemm<T>(queue_, gemm2_event.pointer());
+ gemm2.DoGemm(Layout::kColMajor, Transpose::kNo, a_transpose,
+ m, n - i - block_size, block_size, ConstantNegOne<T>(),
+ x_buffer, x_offset + i * x_ld, x_ld,
+ a_buffer, this_a_offset, a_ld, gemm_alpha,
+ b_buffer, b_offset + (i + block_size) * b_ld, b_ld);
+ gemm2_event.WaitForCompletion();
}
}
}