diff options
Diffstat (limited to 'src/routines/common.hpp')
-rw-r--r-- | src/routines/common.hpp | 8 |
1 files changed, 2 insertions, 6 deletions
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index c99cd39d..e624a2b1 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -88,10 +88,6 @@ StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Cont
     }
   }
 
-  // Upload the scalar argument as a constant buffer to the device (needed for half-precision)
-  auto alpha_buffer = Buffer<T>(context, 1);
-  alpha_buffer.Write(queue, 1, &alpha);
-
   // Retrieves the kernel from the compiled binary
   try {
     auto kernel = Kernel(program, kernel_name);
@@ -101,7 +97,7 @@ StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Cont
       kernel.SetArgument(0, static_cast<int>(src_ld));
       kernel.SetArgument(1, src());
       kernel.SetArgument(2, dest());
-      kernel.SetArgument(3, alpha_buffer());
+      kernel.SetArgument(3, GetRealArg(alpha));
     }
     else {
       kernel.SetArgument(0, static_cast<int>(src_one));
@@ -114,7 +110,7 @@ StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Cont
       kernel.SetArgument(7, static_cast<int>(dest_ld));
       kernel.SetArgument(8, static_cast<int>(dest_offset));
       kernel.SetArgument(9, dest());
-      kernel.SetArgument(10, alpha_buffer());
+      kernel.SetArgument(10, GetRealArg(alpha));
       if (do_pad) {
         kernel.SetArgument(11, static_cast<int>(do_conjugate));
       }