diff options
Diffstat (limited to 'src/routine.cc')
-rw-r--r-- | src/routine.cc | 31 |
1 files changed, 18 insertions, 13 deletions
diff --git a/src/routine.cc b/src/routine.cc index ff7b3e1a..b5ba63eb 100644 --- a/src/routine.cc +++ b/src/routine.cc @@ -26,7 +26,7 @@ template <typename T> std::mutex Routine<T>::program_cache_mutex_; // Constructor: not much here, because no status codes can be returned template <typename T> -Routine<T>::Routine(Queue &queue, Event &event, const std::string &name, +Routine<T>::Routine(Queue &queue, EventPointer event, const std::string &name, const std::vector<std::string> &routines, const Precision precision): precision_(precision), routine_name_(name), @@ -117,7 +117,8 @@ StatusCode Routine<T>::SetUp() { // Enqueues a kernel, waits for completion, and checks for errors template <typename T> StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> &global, - const std::vector<size_t> &local) { + const std::vector<size_t> &local, EventPointer event, + std::vector<Event>& waitForEvents) { // Tests for validity of the local thread sizes if (local.size() > max_work_item_dimensions_) { @@ -141,18 +142,21 @@ StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> &global, // Launches the kernel (and checks for launch errors) try { - kernel.Launch(queue_, global, local, event_); + kernel.Launch(queue_, global, local, event, waitForEvents); } catch (...) { return StatusCode::kKernelLaunchError; } - // Waits for completion of the kernel - try { - queue_.Finish(event_); - } catch (...) { return StatusCode::kKernelRunError; } - // No errors, normal termination of this function return StatusCode::kSuccess; } +// As above, but without an event waiting list +template <typename T> +StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> &global, + const std::vector<size_t> &local, EventPointer event) { + auto emptyWaitingList = std::vector<Event>(); + return RunKernel(kernel, global, local, event, emptyWaitingList); +} + // ================================================================================================= // Tests matrix A for validity: checks for a valid OpenCL buffer, a valid lead-dimension, and for a @@ -258,7 +262,8 @@ StatusCode Routine<T>::TestVectorDot(const size_t n, const Buffer<T> &buffer, co // Copies or transposes a matrix and pads/unpads it with zeros template <typename T> -StatusCode Routine<T>::PadCopyTransposeMatrix(const size_t src_one, const size_t src_two, +StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Event>& waitForEvents, + const size_t src_one, const size_t src_two, const size_t src_ld, const size_t src_offset, const Buffer<T> &src, const size_t dest_one, const size_t dest_two, @@ -340,13 +345,13 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(const size_t src_one, const size_t auto global = std::vector<size_t>{dest_one / db_["TRA_WPT"], dest_two / db_["TRA_WPT"]}; auto local = std::vector<size_t>{db_["TRA_DIM"], db_["TRA_DIM"]}; - status = RunKernel(kernel, global, local); + status = RunKernel(kernel, global, local, event, waitForEvents); } else { auto global = std::vector<size_t>{Ceil(CeilDiv(dest_one, db_["PADTRA_WPT"]), db_["PADTRA_TILE"]), Ceil(CeilDiv(dest_two, db_["PADTRA_WPT"]), db_["PADTRA_TILE"])}; auto local = std::vector<size_t>{db_["PADTRA_TILE"], db_["PADTRA_TILE"]}; - status = RunKernel(kernel, global, local); + status = RunKernel(kernel, global, local, event, waitForEvents); } } else { @@ -354,13 +359,13 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(const size_t src_one, const size_t auto global = std::vector<size_t>{dest_one / db_["COPY_VW"], dest_two / db_["COPY_WPT"]}; auto local = std::vector<size_t>{db_["COPY_DIMX"], db_["COPY_DIMY"]}; - status = RunKernel(kernel, global, local); + status = RunKernel(kernel, global, local, event, waitForEvents); } else { auto global = std::vector<size_t>{Ceil(CeilDiv(dest_one, db_["PAD_WPTX"]), db_["PAD_DIMX"]), Ceil(CeilDiv(dest_two, db_["PAD_WPTY"]), db_["PAD_DIMY"])}; auto local = std::vector<size_t>{db_["PAD_DIMX"], db_["PAD_DIMY"]}; - status = RunKernel(kernel, global, local); + status = RunKernel(kernel, global, local, event, waitForEvents); } } return status; |