summaryrefslogtreecommitdiff
path: root/src/routine.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/routine.cc')
-rw-r--r--src/routine.cc31
1 files changed, 18 insertions, 13 deletions
diff --git a/src/routine.cc b/src/routine.cc
index ff7b3e1a..b5ba63eb 100644
--- a/src/routine.cc
+++ b/src/routine.cc
@@ -26,7 +26,7 @@ template <typename T> std::mutex Routine<T>::program_cache_mutex_;
// Constructor: not much here, because no status codes can be returned
template <typename T>
-Routine<T>::Routine(Queue &queue, Event &event, const std::string &name,
+Routine<T>::Routine(Queue &queue, EventPointer event, const std::string &name,
const std::vector<std::string> &routines, const Precision precision):
precision_(precision),
routine_name_(name),
@@ -117,7 +117,8 @@ StatusCode Routine<T>::SetUp() {
// Enqueues a kernel, waits for completion, and checks for errors
template <typename T>
StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> &global,
- const std::vector<size_t> &local) {
+ const std::vector<size_t> &local, EventPointer event,
+ std::vector<Event>& waitForEvents) {
// Tests for validity of the local thread sizes
if (local.size() > max_work_item_dimensions_) {
@@ -141,18 +142,21 @@ StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> &global,
// Launches the kernel (and checks for launch errors)
try {
- kernel.Launch(queue_, global, local, event_);
+ kernel.Launch(queue_, global, local, event, waitForEvents);
} catch (...) { return StatusCode::kKernelLaunchError; }
- // Waits for completion of the kernel
- try {
- queue_.Finish(event_);
- } catch (...) { return StatusCode::kKernelRunError; }
-
// No errors, normal termination of this function
return StatusCode::kSuccess;
}
+// As above, but without an event waiting list
+template <typename T>
+StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> &global,
+ const std::vector<size_t> &local, EventPointer event) {
+ auto emptyWaitingList = std::vector<Event>();
+ return RunKernel(kernel, global, local, event, emptyWaitingList);
+}
+
// =================================================================================================
// Tests matrix A for validity: checks for a valid OpenCL buffer, a valid lead-dimension, and for a
@@ -258,7 +262,8 @@ StatusCode Routine<T>::TestVectorDot(const size_t n, const Buffer<T> &buffer, co
// Copies or transposes a matrix and pads/unpads it with zeros
template <typename T>
-StatusCode Routine<T>::PadCopyTransposeMatrix(const size_t src_one, const size_t src_two,
+StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Event>& waitForEvents,
+ const size_t src_one, const size_t src_two,
const size_t src_ld, const size_t src_offset,
const Buffer<T> &src,
const size_t dest_one, const size_t dest_two,
@@ -340,13 +345,13 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(const size_t src_one, const size_t
auto global = std::vector<size_t>{dest_one / db_["TRA_WPT"],
dest_two / db_["TRA_WPT"]};
auto local = std::vector<size_t>{db_["TRA_DIM"], db_["TRA_DIM"]};
- status = RunKernel(kernel, global, local);
+ status = RunKernel(kernel, global, local, event, waitForEvents);
}
else {
auto global = std::vector<size_t>{Ceil(CeilDiv(dest_one, db_["PADTRA_WPT"]), db_["PADTRA_TILE"]),
Ceil(CeilDiv(dest_two, db_["PADTRA_WPT"]), db_["PADTRA_TILE"])};
auto local = std::vector<size_t>{db_["PADTRA_TILE"], db_["PADTRA_TILE"]};
- status = RunKernel(kernel, global, local);
+ status = RunKernel(kernel, global, local, event, waitForEvents);
}
}
else {
@@ -354,13 +359,13 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(const size_t src_one, const size_t
auto global = std::vector<size_t>{dest_one / db_["COPY_VW"],
dest_two / db_["COPY_WPT"]};
auto local = std::vector<size_t>{db_["COPY_DIMX"], db_["COPY_DIMY"]};
- status = RunKernel(kernel, global, local);
+ status = RunKernel(kernel, global, local, event, waitForEvents);
}
else {
auto global = std::vector<size_t>{Ceil(CeilDiv(dest_one, db_["PAD_WPTX"]), db_["PAD_DIMX"]),
Ceil(CeilDiv(dest_two, db_["PAD_WPTY"]), db_["PAD_DIMY"])};
auto local = std::vector<size_t>{db_["PAD_DIMX"], db_["PAD_DIMY"]};
- status = RunKernel(kernel, global, local);
+ status = RunKernel(kernel, global, local, event, waitForEvents);
}
}
return status;