summary | refs | log | tree | commit | diff
path: root/src/routine.cc
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2016-06-16 18:07:46 +0200
committer Cedric Nugteren <web@cedricnugteren.nl> 2016-06-16 18:07:46 +0200
commit 52ccaf5b25e14c9ce032315e5e96b1f27886d481 (patch)
tree 087288b7aebf2a06ffc4e7dcbcd4353f7a3be6a7 /src/routine.cc
parent 39b7dbc5e37829abfbcfb77852b9138b31540b42 (diff)
Added XOMATCOPY routines to perform out-of-place matrix scaling, copying, and/or transposing
Diffstat (limited to 'src/routine.cc')
-rw-r--r-- src/routine.cc 15
1 file changed, 11 insertions, 4 deletions
diff --git a/src/routine.cc b/src/routine.cc
index 4b334e60..1cf8bff8 100644
--- a/src/routine.cc
+++ b/src/routine.cc
@@ -302,6 +302,7 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev
const size_t dest_one, const size_t dest_two,
const size_t dest_ld, const size_t dest_offset,
const Buffer<T> &dest,
+ const T alpha,
const Program &program, const bool do_pad,
const bool do_transpose, const bool do_conjugate,
const bool upper, const bool lower,
@@ -339,6 +340,10 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev
}
}
+ // Upload the scalar argument as a constant buffer to the device (needed for half-precision)
+ auto alpha_buffer = Buffer<T>(context_, 1);
+ alpha_buffer.Write(queue_, 1, &alpha);
+
// Retrieves the kernel from the compiled binary
try {
auto kernel = Kernel(program, kernel_name);
@@ -348,6 +353,7 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev
kernel.SetArgument(0, static_cast<int>(src_ld));
kernel.SetArgument(1, src());
kernel.SetArgument(2, dest());
+ kernel.SetArgument(3, alpha_buffer());
}
else {
kernel.SetArgument(0, static_cast<int>(src_one));
@@ -360,13 +366,14 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev
kernel.SetArgument(7, static_cast<int>(dest_ld));
kernel.SetArgument(8, static_cast<int>(dest_offset));
kernel.SetArgument(9, dest());
+ kernel.SetArgument(10, alpha_buffer());
if (do_pad) {
- kernel.SetArgument(10, static_cast<int>(do_conjugate));
+ kernel.SetArgument(11, static_cast<int>(do_conjugate));
}
else {
- kernel.SetArgument(10, static_cast<int>(upper));
- kernel.SetArgument(11, static_cast<int>(lower));
- kernel.SetArgument(12, static_cast<int>(diagonal_imag_zero));
+ kernel.SetArgument(11, static_cast<int>(upper));
+ kernel.SetArgument(12, static_cast<int>(lower));
+ kernel.SetArgument(13, static_cast<int>(diagonal_imag_zero));
}
}