diff options
Diffstat (limited to 'src/routines/common.hpp')
-rw-r--r-- | src/routines/common.hpp | 19 |
1 files changed, 5 insertions, 14 deletions
diff --git a/src/routines/common.hpp b/src/routines/common.hpp index c99cd39d..9d8849c3 100644 --- a/src/routines/common.hpp +++ b/src/routines/common.hpp @@ -29,21 +29,16 @@ namespace clblast { // Enqueues a kernel, waits for completion, and checks for errors StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device, std::vector<size_t> global, const std::vector<size_t> &local, - EventPointer event, std::vector<Event>& waitForEvents); - -// As above, but without an event waiting list -StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device, - std::vector<size_t> global, const std::vector<size_t> &local, - EventPointer event); + EventPointer event, const std::vector<Event> &waitForEvents = {}); // ================================================================================================= // Copies or transposes a matrix and optionally pads/unpads it with zeros. This method is also able // to write to symmetric and triangular matrices through optional arguments. template <typename T> -StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Context &context, +StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Database &db, - EventPointer event, std::vector<Event>& waitForEvents, + EventPointer event, const std::vector<Event> &waitForEvents, const size_t src_one, const size_t src_two, const size_t src_ld, const size_t src_offset, const Buffer<T> &src, @@ -88,10 +83,6 @@ StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Cont } } - // Upload the scalar argument as a constant buffer to the device (needed for half-precision) - auto alpha_buffer = Buffer<T>(context, 1); - alpha_buffer.Write(queue, 1, &alpha); - // Retrieves the kernel from the compiled binary try { auto kernel = Kernel(program, kernel_name); @@ -101,7 +92,7 @@ StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Cont kernel.SetArgument(0, static_cast<int>(src_ld)); kernel.SetArgument(1, src()); kernel.SetArgument(2, dest()); - kernel.SetArgument(3, alpha_buffer()); + kernel.SetArgument(3, GetRealArg(alpha)); } else { kernel.SetArgument(0, static_cast<int>(src_one)); @@ -114,7 +105,7 @@ StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Cont kernel.SetArgument(7, static_cast<int>(dest_ld)); kernel.SetArgument(8, static_cast<int>(dest_offset)); kernel.SetArgument(9, dest()); - kernel.SetArgument(10, alpha_buffer()); + kernel.SetArgument(10, GetRealArg(alpha)); if (do_pad) { kernel.SetArgument(11, static_cast<int>(do_conjugate)); } |