summaryrefslogtreecommitdiff
path: root/src/routines/common.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines/common.hpp')
-rw-r--r--src/routines/common.hpp162
1 files changed, 80 insertions, 82 deletions
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index 9d8849c3..53ca6355 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -27,29 +27,29 @@ namespace clblast {
// =================================================================================================
// Enqueues a kernel, waits for completion, and checks for errors
-StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
- std::vector<size_t> global, const std::vector<size_t> &local,
- EventPointer event, const std::vector<Event> &waitForEvents = {});
+void RunKernel(Kernel &kernel, Queue &queue, const Device &device,
+ std::vector<size_t> global, const std::vector<size_t> &local,
+ EventPointer event, const std::vector<Event> &waitForEvents = {});
// =================================================================================================
// Copies or transposes a matrix and optionally pads/unpads it with zeros. This method is also able
// to write to symmetric and triangular matrices through optional arguments.
template <typename T>
-StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device,
- const Database &db,
- EventPointer event, const std::vector<Event> &waitForEvents,
- const size_t src_one, const size_t src_two,
- const size_t src_ld, const size_t src_offset,
- const Buffer<T> &src,
- const size_t dest_one, const size_t dest_two,
- const size_t dest_ld, const size_t dest_offset,
- const Buffer<T> &dest,
- const T alpha,
- const Program &program, const bool do_pad,
- const bool do_transpose, const bool do_conjugate,
- const bool upper = false, const bool lower = false,
- const bool diagonal_imag_zero = false) {
+void PadCopyTransposeMatrix(Queue &queue, const Device &device,
+ const Database &db,
+ EventPointer event, const std::vector<Event> &waitForEvents,
+ const size_t src_one, const size_t src_two,
+ const size_t src_ld, const size_t src_offset,
+ const Buffer<T> &src,
+ const size_t dest_one, const size_t dest_two,
+ const size_t dest_ld, const size_t dest_offset,
+ const Buffer<T> &dest,
+ const T alpha,
+ const Program &program, const bool do_pad,
+ const bool do_transpose, const bool do_conjugate,
+ const bool upper = false, const bool lower = false,
+ const bool diagonal_imag_zero = false) {
// Determines whether or not the fast-version could potentially be used
auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) && (do_conjugate == false) &&
@@ -61,8 +61,8 @@ StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device,
if (do_transpose) {
if (use_fast_kernel &&
IsMultiple(src_ld, db["TRA_WPT"]) &&
- IsMultiple(src_one, db["TRA_WPT"]*db["TRA_WPT"]) &&
- IsMultiple(src_two, db["TRA_WPT"]*db["TRA_WPT"])) {
+ IsMultiple(src_one, db["TRA_WPT"]*db["TRA_DIM"]) &&
+ IsMultiple(src_two, db["TRA_WPT"]*db["TRA_DIM"])) {
kernel_name = "TransposeMatrixFast";
}
else {
@@ -84,77 +84,75 @@ StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device,
}
// Retrieves the kernel from the compiled binary
- try {
- auto kernel = Kernel(program, kernel_name);
+ auto kernel = Kernel(program, kernel_name);
- // Sets the kernel arguments
- if (use_fast_kernel) {
- kernel.SetArgument(0, static_cast<int>(src_ld));
- kernel.SetArgument(1, src());
- kernel.SetArgument(2, dest());
- kernel.SetArgument(3, GetRealArg(alpha));
+ // Sets the kernel arguments
+ if (use_fast_kernel) {
+ kernel.SetArgument(0, static_cast<int>(src_ld));
+ kernel.SetArgument(1, src());
+ kernel.SetArgument(2, dest());
+ kernel.SetArgument(3, GetRealArg(alpha));
+ }
+ else {
+ kernel.SetArgument(0, static_cast<int>(src_one));
+ kernel.SetArgument(1, static_cast<int>(src_two));
+ kernel.SetArgument(2, static_cast<int>(src_ld));
+ kernel.SetArgument(3, static_cast<int>(src_offset));
+ kernel.SetArgument(4, src());
+ kernel.SetArgument(5, static_cast<int>(dest_one));
+ kernel.SetArgument(6, static_cast<int>(dest_two));
+ kernel.SetArgument(7, static_cast<int>(dest_ld));
+ kernel.SetArgument(8, static_cast<int>(dest_offset));
+ kernel.SetArgument(9, dest());
+ kernel.SetArgument(10, GetRealArg(alpha));
+ if (do_pad) {
+ kernel.SetArgument(11, static_cast<int>(do_conjugate));
}
else {
- kernel.SetArgument(0, static_cast<int>(src_one));
- kernel.SetArgument(1, static_cast<int>(src_two));
- kernel.SetArgument(2, static_cast<int>(src_ld));
- kernel.SetArgument(3, static_cast<int>(src_offset));
- kernel.SetArgument(4, src());
- kernel.SetArgument(5, static_cast<int>(dest_one));
- kernel.SetArgument(6, static_cast<int>(dest_two));
- kernel.SetArgument(7, static_cast<int>(dest_ld));
- kernel.SetArgument(8, static_cast<int>(dest_offset));
- kernel.SetArgument(9, dest());
- kernel.SetArgument(10, GetRealArg(alpha));
- if (do_pad) {
- kernel.SetArgument(11, static_cast<int>(do_conjugate));
- }
- else {
- kernel.SetArgument(11, static_cast<int>(upper));
- kernel.SetArgument(12, static_cast<int>(lower));
- kernel.SetArgument(13, static_cast<int>(diagonal_imag_zero));
- }
+ kernel.SetArgument(11, static_cast<int>(upper));
+ kernel.SetArgument(12, static_cast<int>(lower));
+ kernel.SetArgument(13, static_cast<int>(diagonal_imag_zero));
}
+ }
- // Launches the kernel and returns the error code. Uses global and local thread sizes based on
- // parameters in the database.
- if (do_transpose) {
- if (use_fast_kernel) {
- const auto global = std::vector<size_t>{
- dest_one / db["TRA_WPT"],
- dest_two / db["TRA_WPT"]
- };
- const auto local = std::vector<size_t>{db["TRA_DIM"], db["TRA_DIM"]};
- return RunKernel(kernel, queue, device, global, local, event, waitForEvents);
- }
- else {
- const auto global = std::vector<size_t>{
- Ceil(CeilDiv(dest_one, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
- Ceil(CeilDiv(dest_two, db["PADTRA_WPT"]), db["PADTRA_TILE"])
- };
- const auto local = std::vector<size_t>{db["PADTRA_TILE"], db["PADTRA_TILE"]};
- return RunKernel(kernel, queue, device, global, local, event, waitForEvents);
- }
+ // Launches the kernel and returns the error code. Uses global and local thread sizes based on
+ // parameters in the database.
+ if (do_transpose) {
+ if (use_fast_kernel) {
+ const auto global = std::vector<size_t>{
+ dest_one / db["TRA_WPT"],
+ dest_two / db["TRA_WPT"]
+ };
+ const auto local = std::vector<size_t>{db["TRA_DIM"], db["TRA_DIM"]};
+ RunKernel(kernel, queue, device, global, local, event, waitForEvents);
}
else {
- if (use_fast_kernel) {
- const auto global = std::vector<size_t>{
- dest_one / db["COPY_VW"],
- dest_two / db["COPY_WPT"]
- };
- const auto local = std::vector<size_t>{db["COPY_DIMX"], db["COPY_DIMY"]};
- return RunKernel(kernel, queue, device, global, local, event, waitForEvents);
- }
- else {
- const auto global = std::vector<size_t>{
- Ceil(CeilDiv(dest_one, db["PAD_WPTX"]), db["PAD_DIMX"]),
- Ceil(CeilDiv(dest_two, db["PAD_WPTY"]), db["PAD_DIMY"])
- };
- const auto local = std::vector<size_t>{db["PAD_DIMX"], db["PAD_DIMY"]};
- return RunKernel(kernel, queue, device, global, local, event, waitForEvents);
- }
+ const auto global = std::vector<size_t>{
+ Ceil(CeilDiv(dest_one, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
+ Ceil(CeilDiv(dest_two, db["PADTRA_WPT"]), db["PADTRA_TILE"])
+ };
+ const auto local = std::vector<size_t>{db["PADTRA_TILE"], db["PADTRA_TILE"]};
+ RunKernel(kernel, queue, device, global, local, event, waitForEvents);
}
- } catch (...) { return StatusCode::kInvalidKernel; }
+ }
+ else {
+ if (use_fast_kernel) {
+ const auto global = std::vector<size_t>{
+ dest_one / db["COPY_VW"],
+ dest_two / db["COPY_WPT"]
+ };
+ const auto local = std::vector<size_t>{db["COPY_DIMX"], db["COPY_DIMY"]};
+ RunKernel(kernel, queue, device, global, local, event, waitForEvents);
+ }
+ else {
+ const auto global = std::vector<size_t>{
+ Ceil(CeilDiv(dest_one, db["PAD_WPTX"]), db["PAD_DIMX"]),
+ Ceil(CeilDiv(dest_two, db["PAD_WPTY"]), db["PAD_DIMY"])
+ };
+ const auto local = std::vector<size_t>{db["PAD_DIMX"], db["PAD_DIMY"]};
+ RunKernel(kernel, queue, device, global, local, event, waitForEvents);
+ }
+ }
}
// =================================================================================================