summary | refs | log | tree | commit | diff
path: root/src/routines
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines')
-rw-r--r-- src/routines/common.hpp 9
-rw-r--r-- src/routines/level2/xtrsv.cpp 13
-rw-r--r-- src/routines/level2/xtrsv.hpp 4
-rw-r--r-- src/routines/level3/xgemm.cpp 1
-rw-r--r-- src/routines/level3/xherk.cpp 1
-rw-r--r-- src/routines/level3/xsyrk.cpp 1
-rw-r--r-- src/routines/level3/xtrsm.cpp 2
-rw-r--r-- src/routines/level3/xtrsm.hpp 1
-rw-r--r-- src/routines/levelx/xgemmbatched.cpp 1
-rw-r--r-- src/routines/levelx/xgemmstridedbatched.cpp 1
10 files changed, 23 insertions, 11 deletions
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index c30a2e0e..c6db0152 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -76,6 +76,7 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
// Determines the right kernel
auto kernel_name = std::string{};
+ auto pad_kernel = false;
if (do_transpose) {
if (use_fast_kernel &&
IsMultiple(src_ld, db["TRA_WPT"]) &&
@@ -85,7 +86,8 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
}
else {
use_fast_kernel = false;
- kernel_name = (do_pad) ? "TransposePadMatrix" : "TransposeMatrix";
+ pad_kernel = (do_pad || do_conjugate);
+ kernel_name = (pad_kernel) ? "TransposePadMatrix" : "TransposeMatrix";
}
}
else {
@@ -97,7 +99,8 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
}
else {
use_fast_kernel = false;
- kernel_name = (do_pad) ? "CopyPadMatrix" : "CopyMatrix";
+ pad_kernel = do_pad;
+ kernel_name = (pad_kernel) ? "CopyPadMatrix" : "CopyMatrix";
}
}
@@ -123,7 +126,7 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
kernel.SetArgument(8, static_cast<int>(dest_offset));
kernel.SetArgument(9, dest());
kernel.SetArgument(10, GetRealArg(alpha));
- if (do_pad) {
+ if (pad_kernel) {
kernel.SetArgument(11, static_cast<int>(do_conjugate));
}
else {
diff --git a/src/routines/level2/xtrsv.cpp b/src/routines/level2/xtrsv.cpp
index 76401753..2a5a5664 100644
--- a/src/routines/level2/xtrsv.cpp
+++ b/src/routines/level2/xtrsv.cpp
@@ -33,7 +33,8 @@ void Xtrsv<T>::Substitution(const Layout layout, const Triangle triangle,
const size_t n,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_inc,
- const Buffer<T> &x_buffer, const size_t x_offset, const size_t x_inc) {
+ const Buffer<T> &x_buffer, const size_t x_offset, const size_t x_inc,
+ EventPointer event) {
if (n > db_["TRSV_BLOCK_SIZE"]) { throw BLASError(StatusCode::kUnexpectedError); };
@@ -69,9 +70,7 @@ void Xtrsv<T>::Substitution(const Layout layout, const Triangle triangle,
// Launches the kernel
const auto local = std::vector<size_t>{db_["TRSV_BLOCK_SIZE"]};
const auto global = std::vector<size_t>{Ceil(n, db_["TRSV_BLOCK_SIZE"])};
- auto event = Event();
- RunKernel(kernel, queue_, device_, global, local, event.pointer());
- event.WaitForCompletion();
+ RunKernel(kernel, queue_, device_, global, local, event);
}
// =================================================================================================
@@ -146,14 +145,16 @@ void Xtrsv<T>::DoTrsv(const Layout layout, const Triangle triangle,
}
// Runs the triangular substitution for the block size
+ auto sub_event = Event();
Substitution(layout, triangle, a_transpose, diagonal, block_size,
a_buffer, a_offset + col + col*a_ld, a_ld,
b_buffer, b_offset + col*b_inc, b_inc,
- x_buffer, x_offset + col*x_inc, x_inc);
+ x_buffer, x_offset + col*x_inc, x_inc, sub_event.pointer());
+ sub_event.WaitForCompletion();
}
// Retrieves the results
- x_buffer.CopyTo(queue_, x_size, b_buffer);
+ x_buffer.CopyToAsync(queue_, x_size, b_buffer, event_);
}
// =================================================================================================
diff --git a/src/routines/level2/xtrsv.hpp b/src/routines/level2/xtrsv.hpp
index 67e626a1..8a900a35 100644
--- a/src/routines/level2/xtrsv.hpp
+++ b/src/routines/level2/xtrsv.hpp
@@ -32,6 +32,7 @@ class Xtrsv: public Xgemv<T> {
using Xgemv<T>::device_;
using Xgemv<T>::db_;
using Xgemv<T>::program_;
+ using Xgemv<T>::event_;
using Xgemv<T>::DoGemv;
// Constructor
@@ -50,7 +51,8 @@ class Xtrsv: public Xgemv<T> {
const size_t n,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_inc,
- const Buffer<T> &x_buffer, const size_t offset_x, const size_t x_inc);
+ const Buffer<T> &x_buffer, const size_t offset_x, const size_t x_inc,
+ EventPointer event);
};
// =================================================================================================
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index fd5a20db..cb24460a 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -40,6 +40,7 @@ Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part1.opencl"
#include "../../kernels/level3/xgemm_part2.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part3.opencl"
#include "../../kernels/level3/xgemm_part4.opencl"
}) {
diff --git a/src/routines/level3/xherk.cpp b/src/routines/level3/xherk.cpp
index 6912d3a9..2e6f30ec 100644
--- a/src/routines/level3/xherk.cpp
+++ b/src/routines/level3/xherk.cpp
@@ -32,6 +32,7 @@ Xherk<T,U>::Xherk(Queue &queue, EventPointer event, const std::string &name):
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part1.opencl"
#include "../../kernels/level3/xgemm_part2.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part3.opencl"
#include "../../kernels/level3/xgemm_part4.opencl"
}) {
diff --git a/src/routines/level3/xsyrk.cpp b/src/routines/level3/xsyrk.cpp
index 6bb2a24f..5ffdc028 100644
--- a/src/routines/level3/xsyrk.cpp
+++ b/src/routines/level3/xsyrk.cpp
@@ -32,6 +32,7 @@ Xsyrk<T>::Xsyrk(Queue &queue, EventPointer event, const std::string &name):
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part1.opencl"
#include "../../kernels/level3/xgemm_part2.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part3.opencl"
#include "../../kernels/level3/xgemm_part4.opencl"
}) {
diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp
index 905660ff..fe5d1e14 100644
--- a/src/routines/level3/xtrsm.cpp
+++ b/src/routines/level3/xtrsm.cpp
@@ -246,7 +246,7 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle,
}
// Retrieves the results
- x_buffer.CopyTo(queue_, b_size, b_buffer);
+ x_buffer.CopyToAsync(queue_, b_size, b_buffer, event_);
}
// =================================================================================================
diff --git a/src/routines/level3/xtrsm.hpp b/src/routines/level3/xtrsm.hpp
index 5b42398e..871d7253 100644
--- a/src/routines/level3/xtrsm.hpp
+++ b/src/routines/level3/xtrsm.hpp
@@ -31,6 +31,7 @@ class Xtrsm: public Xgemm<T> {
using Xgemm<T>::device_;
using Xgemm<T>::db_;
using Xgemm<T>::program_;
+ using Xgemm<T>::event_;
using Xgemm<T>::DoGemm;
// Constructor
diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp
index 2bbc5007..b12b8734 100644
--- a/src/routines/levelx/xgemmbatched.cpp
+++ b/src/routines/levelx/xgemmbatched.cpp
@@ -38,6 +38,7 @@ XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::strin
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part1.opencl"
#include "../../kernels/level3/xgemm_part2.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part3.opencl"
#include "../../kernels/level3/xgemm_part4.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp
index 30c161cc..d9e3ebba 100644
--- a/src/routines/levelx/xgemmstridedbatched.cpp
+++ b/src/routines/levelx/xgemmstridedbatched.cpp
@@ -37,6 +37,7 @@ XgemmStridedBatched<T>::XgemmStridedBatched(Queue &queue, EventPointer event, co
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part1.opencl"
#include "../../kernels/level3/xgemm_part2.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part3.opencl"
#include "../../kernels/level3/xgemm_part4.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013