Diffstat (limited to 'src/routine.cc')
-rw-r--r--  src/routine.cc  158
1 file changed, 14 insertions(+), 144 deletions(-)
diff --git a/src/routine.cc b/src/routine.cc
index ee3ba341..c59cbc11 100644
--- a/src/routine.cc
+++ b/src/routine.cc
@@ -30,9 +30,6 @@ Routine<T>::Routine(Queue &queue, EventPointer event, const std::string &name,
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
device_name_(device_.Name()),
- max_work_item_dimensions_(device_.MaxWorkItemDimensions()),
- max_work_item_sizes_(device_.MaxWorkItemSizes()),
- max_work_group_size_(device_.MaxWorkGroupSize()),
db_(queue_, routines, precision_) {
}
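The three members removed above cached device limits at construction time; after this change the limits are read from the Device wrapper whenever a kernel is launched. A minimal sketch of the on-demand queries, assuming a `device` object of the same Device class used in this file (not code from this commit):
// Sketch only: the values formerly stored in max_work_item_dimensions_,
// max_work_item_sizes_ and max_work_group_size_ are now queried directly,
// as the refactored RunKernel below does.
const auto max_dims       = device.MaxWorkItemDimensions();  // number of work-item dimensions
const auto max_item_sizes = device.MaxWorkItemSizes();       // per-dimension work-item limits
const auto max_group_size = device.MaxWorkGroupSize();       // total work-group size limit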
@@ -135,21 +132,21 @@ StatusCode Routine<T>::SetUp() {
// =================================================================================================
// Enqueues a kernel, waits for completion, and checks for errors
-template <typename T>
-StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> global,
- const std::vector<size_t> &local, EventPointer event,
- std::vector<Event>& waitForEvents) {
+StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device,
+ std::vector<size_t> global, const std::vector<size_t> &local,
+ EventPointer event, std::vector<Event>& waitForEvents) {
// Tests for validity of the local thread sizes
- if (local.size() > max_work_item_dimensions_) {
+ if (local.size() > device.MaxWorkItemDimensions()) {
return StatusCode::kInvalidLocalNumDimensions;
}
+ const auto max_work_item_sizes = device.MaxWorkItemSizes();
for (auto i=size_t{0}; i<local.size(); ++i) {
- if (local[i] > max_work_item_sizes_[i]) { return StatusCode::kInvalidLocalThreadsDim; }
+ if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; }
}
auto local_size = size_t{1};
for (auto &item: local) { local_size *= item; }
- if (local_size > max_work_group_size_) { return StatusCode::kInvalidLocalThreadsTotal; }
+ if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; }
// Make sure the global thread sizes are at least equal to the local sizes
for (auto i=size_t{0}; i<global.size(); ++i) {
@@ -157,12 +154,12 @@ StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> global,
}
// Tests for local memory usage
- const auto local_mem_usage = kernel.LocalMemUsage(device_);
- if (!device_.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; }
+ const auto local_mem_usage = kernel.LocalMemUsage(device);
+ if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; }
// Launches the kernel (and checks for launch errors)
try {
- kernel.Launch(queue_, global, local, event, waitForEvents);
+ kernel.Launch(queue, global, local, event, waitForEvents);
} catch (...) { return StatusCode::kKernelLaunchError; }
// No errors, normal termination of this function
@@ -170,138 +167,11 @@ StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> global,
}
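RunKernel is now a free function that receives the queue and device explicitly instead of using Routine<T> members. An illustrative call site, assuming placeholder names (LaunchExample, the global/local sizes, and the dependency list are not part of this commit):
// Hypothetical caller: a routine passes its own queue and device to the
// free-function RunKernel together with the kernel's thread configuration.
StatusCode LaunchExample(Kernel &kernel, Queue &queue, const Device &device,
                         EventPointer event, std::vector<Event> &dependencies) {
  const auto global = std::vector<size_t>{2048, 2048};  // total thread sizes (illustrative)
  const auto local  = std::vector<size_t>{16, 16};      // work-group sizes (illustrative)
  return RunKernel(kernel, queue, device, global, local, event, dependencies);
}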
// As above, but without an event waiting list
-template <typename T>
-StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> global,
- const std::vector<size_t> &local, EventPointer event) {
+StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device,
+ std::vector<size_t> global, const std::vector<size_t> &local,
+ EventPointer event) {
auto emptyWaitingList = std::vector<Event>();
- return RunKernel(kernel, global, local, event, emptyWaitingList);
-}
-
-// =================================================================================================
-
-// Copies or transposes a matrix and optionally pads/unpads it with zeros
-template <typename T>
-StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Event>& waitForEvents,
- const size_t src_one, const size_t src_two,
- const size_t src_ld, const size_t src_offset,
- const Buffer<T> &src,
- const size_t dest_one, const size_t dest_two,
- const size_t dest_ld, const size_t dest_offset,
- const Buffer<T> &dest,
- const T alpha,
- const Program &program, const bool do_pad,
- const bool do_transpose, const bool do_conjugate,
- const bool upper, const bool lower,
- const bool diagonal_imag_zero) {
-
- // Determines whether or not the fast-version could potentially be used
- auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) && (do_conjugate == false) &&
- (src_one == dest_one) && (src_two == dest_two) && (src_ld == dest_ld) &&
- (upper == false) && (lower == false) && (diagonal_imag_zero == false);
-
- // Determines the right kernel
- auto kernel_name = std::string{};
- if (do_transpose) {
- if (use_fast_kernel &&
- IsMultiple(src_ld, db_["TRA_WPT"]) &&
- IsMultiple(src_one, db_["TRA_WPT"]*db_["TRA_WPT"]) &&
- IsMultiple(src_two, db_["TRA_WPT"]*db_["TRA_WPT"])) {
- kernel_name = "TransposeMatrixFast";
- }
- else {
- use_fast_kernel = false;
- kernel_name = (do_pad) ? "TransposePadMatrix" : "TransposeMatrix";
- }
- }
- else {
- if (use_fast_kernel &&
- IsMultiple(src_ld, db_["COPY_VW"]) &&
- IsMultiple(src_one, db_["COPY_VW"]*db_["COPY_DIMX"]) &&
- IsMultiple(src_two, db_["COPY_WPT"]*db_["COPY_DIMY"])) {
- kernel_name = "CopyMatrixFast";
- }
- else {
- use_fast_kernel = false;
- kernel_name = (do_pad) ? "CopyPadMatrix" : "CopyMatrix";
- }
- }
-
- // Upload the scalar argument as a constant buffer to the device (needed for half-precision)
- auto alpha_buffer = Buffer<T>(context_, 1);
- alpha_buffer.Write(queue_, 1, &alpha);
-
- // Retrieves the kernel from the compiled binary
- try {
- auto kernel = Kernel(program, kernel_name);
-
- // Sets the kernel arguments
- if (use_fast_kernel) {
- kernel.SetArgument(0, static_cast<int>(src_ld));
- kernel.SetArgument(1, src());
- kernel.SetArgument(2, dest());
- kernel.SetArgument(3, alpha_buffer());
- }
- else {
- kernel.SetArgument(0, static_cast<int>(src_one));
- kernel.SetArgument(1, static_cast<int>(src_two));
- kernel.SetArgument(2, static_cast<int>(src_ld));
- kernel.SetArgument(3, static_cast<int>(src_offset));
- kernel.SetArgument(4, src());
- kernel.SetArgument(5, static_cast<int>(dest_one));
- kernel.SetArgument(6, static_cast<int>(dest_two));
- kernel.SetArgument(7, static_cast<int>(dest_ld));
- kernel.SetArgument(8, static_cast<int>(dest_offset));
- kernel.SetArgument(9, dest());
- kernel.SetArgument(10, alpha_buffer());
- if (do_pad) {
- kernel.SetArgument(11, static_cast<int>(do_conjugate));
- }
- else {
- kernel.SetArgument(11, static_cast<int>(upper));
- kernel.SetArgument(12, static_cast<int>(lower));
- kernel.SetArgument(13, static_cast<int>(diagonal_imag_zero));
- }
- }
-
- // Launches the kernel and returns the error code. Uses global and local thread sizes based on
- // parameters in the database.
- if (do_transpose) {
- if (use_fast_kernel) {
- const auto global = std::vector<size_t>{
- dest_one / db_["TRA_WPT"],
- dest_two / db_["TRA_WPT"]
- };
- const auto local = std::vector<size_t>{db_["TRA_DIM"], db_["TRA_DIM"]};
- return RunKernel(kernel, global, local, event, waitForEvents);
- }
- else {
- const auto global = std::vector<size_t>{
- Ceil(CeilDiv(dest_one, db_["PADTRA_WPT"]), db_["PADTRA_TILE"]),
- Ceil(CeilDiv(dest_two, db_["PADTRA_WPT"]), db_["PADTRA_TILE"])
- };
- const auto local = std::vector<size_t>{db_["PADTRA_TILE"], db_["PADTRA_TILE"]};
- return RunKernel(kernel, global, local, event, waitForEvents);
- }
- }
- else {
- if (use_fast_kernel) {
- const auto global = std::vector<size_t>{
- dest_one / db_["COPY_VW"],
- dest_two / db_["COPY_WPT"]
- };
- const auto local = std::vector<size_t>{db_["COPY_DIMX"], db_["COPY_DIMY"]};
- return RunKernel(kernel, global, local, event, waitForEvents);
- }
- else {
- const auto global = std::vector<size_t>{
- Ceil(CeilDiv(dest_one, db_["PAD_WPTX"]), db_["PAD_DIMX"]),
- Ceil(CeilDiv(dest_two, db_["PAD_WPTY"]), db_["PAD_DIMY"])
- };
- const auto local = std::vector<size_t>{db_["PAD_DIMX"], db_["PAD_DIMY"]};
- return RunKernel(kernel, global, local, event, waitForEvents);
- }
- }
- } catch (...) { return StatusCode::kInvalidKernel; }
+ return RunKernel(kernel, queue, device, global, local, event, emptyWaitingList);
}
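The overload above forwards to the full version with an empty waiting list. An illustrative call without dependencies, reusing the placeholder names from the earlier sketch:
// Hypothetical call site: when there is nothing to wait for, the waiting-list
// argument is simply omitted and an empty list is substituted internally.
const auto status = RunKernel(kernel, queue, device, global, local, event);
if (status != StatusCode::kSuccess) { return status; }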
// =================================================================================================