diff options
Diffstat (limited to 'src/routine.cc')
-rw-r--r-- | src/routine.cc | 158 |
1 files changed, 14 insertions, 144 deletions
diff --git a/src/routine.cc b/src/routine.cc index ee3ba341..c59cbc11 100644 --- a/src/routine.cc +++ b/src/routine.cc @@ -30,9 +30,6 @@ Routine<T>::Routine(Queue &queue, EventPointer event, const std::string &name, context_(queue_.GetContext()), device_(queue_.GetDevice()), device_name_(device_.Name()), - max_work_item_dimensions_(device_.MaxWorkItemDimensions()), - max_work_item_sizes_(device_.MaxWorkItemSizes()), - max_work_group_size_(device_.MaxWorkGroupSize()), db_(queue_, routines, precision_) { } @@ -135,21 +132,21 @@ StatusCode Routine<T>::SetUp() { // ================================================================================================= // Enqueues a kernel, waits for completion, and checks for errors -template <typename T> -StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> global, - const std::vector<size_t> &local, EventPointer event, - std::vector<Event>& waitForEvents) { +StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device, + std::vector<size_t> global, const std::vector<size_t> &local, + EventPointer event, std::vector<Event>& waitForEvents) { // Tests for validity of the local thread sizes - if (local.size() > max_work_item_dimensions_) { + if (local.size() > device.MaxWorkItemDimensions()) { return StatusCode::kInvalidLocalNumDimensions; } + const auto max_work_item_sizes = device.MaxWorkItemSizes(); for (auto i=size_t{0}; i<local.size(); ++i) { - if (local[i] > max_work_item_sizes_[i]) { return StatusCode::kInvalidLocalThreadsDim; } + if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; } } auto local_size = size_t{1}; for (auto &item: local) { local_size *= item; } - if (local_size > max_work_group_size_) { return StatusCode::kInvalidLocalThreadsTotal; } + if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; } // Make sure the global thread sizes are at least equal to the local sizes for (auto i=size_t{0}; i<global.size(); ++i) { @@ -157,12 +154,12 @@ StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> global, } // Tests for local memory usage - const auto local_mem_usage = kernel.LocalMemUsage(device_); - if (!device_.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; } + const auto local_mem_usage = kernel.LocalMemUsage(device); + if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; } // Launches the kernel (and checks for launch errors) try { - kernel.Launch(queue_, global, local, event, waitForEvents); + kernel.Launch(queue, global, local, event, waitForEvents); } catch (...) { return StatusCode::kKernelLaunchError; } // No errors, normal termination of this function @@ -170,138 +167,11 @@ StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> global, } // As above, but without an event waiting list -template <typename T> -StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> global, - const std::vector<size_t> &local, EventPointer event) { +StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device, + std::vector<size_t> global, const std::vector<size_t> &local, + EventPointer event) { auto emptyWaitingList = std::vector<Event>(); - return RunKernel(kernel, global, local, event, emptyWaitingList); -} - -// ================================================================================================= - -// Copies or transposes a matrix and optionally pads/unpads it with zeros -template <typename T> -StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Event>& waitForEvents, - const size_t src_one, const size_t src_two, - const size_t src_ld, const size_t src_offset, - const Buffer<T> &src, - const size_t dest_one, const size_t dest_two, - const size_t dest_ld, const size_t dest_offset, - const Buffer<T> &dest, - const T alpha, - const Program &program, const bool do_pad, - const bool do_transpose, const bool do_conjugate, - const bool upper, const bool lower, - const bool diagonal_imag_zero) { - - // Determines whether or not the fast-version could potentially be used - auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) && (do_conjugate == false) && - (src_one == dest_one) && (src_two == dest_two) && (src_ld == dest_ld) && - (upper == false) && (lower == false) && (diagonal_imag_zero == false); - - // Determines the right kernel - auto kernel_name = std::string{}; - if (do_transpose) { - if (use_fast_kernel && - IsMultiple(src_ld, db_["TRA_WPT"]) && - IsMultiple(src_one, db_["TRA_WPT"]*db_["TRA_WPT"]) && - IsMultiple(src_two, db_["TRA_WPT"]*db_["TRA_WPT"])) { - kernel_name = "TransposeMatrixFast"; - } - else { - use_fast_kernel = false; - kernel_name = (do_pad) ? "TransposePadMatrix" : "TransposeMatrix"; - } - } - else { - if (use_fast_kernel && - IsMultiple(src_ld, db_["COPY_VW"]) && - IsMultiple(src_one, db_["COPY_VW"]*db_["COPY_DIMX"]) && - IsMultiple(src_two, db_["COPY_WPT"]*db_["COPY_DIMY"])) { - kernel_name = "CopyMatrixFast"; - } - else { - use_fast_kernel = false; - kernel_name = (do_pad) ? "CopyPadMatrix" : "CopyMatrix"; - } - } - - // Upload the scalar argument as a constant buffer to the device (needed for half-precision) - auto alpha_buffer = Buffer<T>(context_, 1); - alpha_buffer.Write(queue_, 1, &alpha); - - // Retrieves the kernel from the compiled binary - try { - auto kernel = Kernel(program, kernel_name); - - // Sets the kernel arguments - if (use_fast_kernel) { - kernel.SetArgument(0, static_cast<int>(src_ld)); - kernel.SetArgument(1, src()); - kernel.SetArgument(2, dest()); - kernel.SetArgument(3, alpha_buffer()); - } - else { - kernel.SetArgument(0, static_cast<int>(src_one)); - kernel.SetArgument(1, static_cast<int>(src_two)); - kernel.SetArgument(2, static_cast<int>(src_ld)); - kernel.SetArgument(3, static_cast<int>(src_offset)); - kernel.SetArgument(4, src()); - kernel.SetArgument(5, static_cast<int>(dest_one)); - kernel.SetArgument(6, static_cast<int>(dest_two)); - kernel.SetArgument(7, static_cast<int>(dest_ld)); - kernel.SetArgument(8, static_cast<int>(dest_offset)); - kernel.SetArgument(9, dest()); - kernel.SetArgument(10, alpha_buffer()); - if (do_pad) { - kernel.SetArgument(11, static_cast<int>(do_conjugate)); - } - else { - kernel.SetArgument(11, static_cast<int>(upper)); - kernel.SetArgument(12, static_cast<int>(lower)); - kernel.SetArgument(13, static_cast<int>(diagonal_imag_zero)); - } - } - - // Launches the kernel and returns the error code. Uses global and local thread sizes based on - // parameters in the database. - if (do_transpose) { - if (use_fast_kernel) { - const auto global = std::vector<size_t>{ - dest_one / db_["TRA_WPT"], - dest_two / db_["TRA_WPT"] - }; - const auto local = std::vector<size_t>{db_["TRA_DIM"], db_["TRA_DIM"]}; - return RunKernel(kernel, global, local, event, waitForEvents); - } - else { - const auto global = std::vector<size_t>{ - Ceil(CeilDiv(dest_one, db_["PADTRA_WPT"]), db_["PADTRA_TILE"]), - Ceil(CeilDiv(dest_two, db_["PADTRA_WPT"]), db_["PADTRA_TILE"]) - }; - const auto local = std::vector<size_t>{db_["PADTRA_TILE"], db_["PADTRA_TILE"]}; - return RunKernel(kernel, global, local, event, waitForEvents); - } - } - else { - if (use_fast_kernel) { - const auto global = std::vector<size_t>{ - dest_one / db_["COPY_VW"], - dest_two / db_["COPY_WPT"] - }; - const auto local = std::vector<size_t>{db_["COPY_DIMX"], db_["COPY_DIMY"]}; - return RunKernel(kernel, global, local, event, waitForEvents); - } - else { - const auto global = std::vector<size_t>{ - Ceil(CeilDiv(dest_one, db_["PAD_WPTX"]), db_["PAD_DIMX"]), - Ceil(CeilDiv(dest_two, db_["PAD_WPTY"]), db_["PAD_DIMY"]) - }; - const auto local = std::vector<size_t>{db_["PAD_DIMX"], db_["PAD_DIMY"]}; - return RunKernel(kernel, global, local, event, waitForEvents); - } - } - } catch (...) { return StatusCode::kInvalidKernel; } + return RunKernel(kernel, queue, device, global, local, event, emptyWaitingList); } // ================================================================================================= |