summaryrefslogtreecommitdiff
path: root/src/routine.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/routine.cc')
-rw-r--r--src/routine.cc10
1 files changed, 8 insertions, 2 deletions
diff --git a/src/routine.cc b/src/routine.cc
index a4e0bb37..4b7ece41 100644
--- a/src/routine.cc
+++ b/src/routine.cc
@@ -210,11 +210,13 @@ StatusCode Routine::PadCopyTransposeMatrix(const size_t src_one, const size_t sr
const size_t dest_ld, const size_t dest_offset,
const Buffer &dest,
const bool do_transpose, const bool do_conjugate,
- const bool pad, const Program &program) {
+ const bool pad, const bool upper, const bool lower,
+ const Program &program) {
// Determines whether or not the fast-version could potentially be used
auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) && (do_conjugate == false) &&
- (src_one == dest_one) && (src_two == dest_two) && (src_ld == dest_ld);
+ (src_one == dest_one) && (src_two == dest_two) && (src_ld == dest_ld) &&
+ (upper == false) && (lower == false);
// Determines the right kernel
auto kernel_name = std::string{};
@@ -267,6 +269,10 @@ StatusCode Routine::PadCopyTransposeMatrix(const size_t src_one, const size_t sr
if (pad) {
kernel.SetArgument(10, static_cast<int>(do_conjugate));
}
+ else {
+ kernel.SetArgument(10, static_cast<int>(upper));
+ kernel.SetArgument(11, static_cast<int>(lower));
+ }
}
// Launches the kernel and returns the error code. Uses global and local thread sizes based on