diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/routines/common.hpp | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index c30a2e0e..c6db0152 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -76,6 +76,7 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
 
   // Determines the right kernel
   auto kernel_name = std::string{};
+  auto pad_kernel = false;
   if (do_transpose) {
     if (use_fast_kernel &&
         IsMultiple(src_ld, db["TRA_WPT"]) &&
@@ -85,7 +86,8 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
     }
     else {
       use_fast_kernel = false;
-      kernel_name = (do_pad) ? "TransposePadMatrix" : "TransposeMatrix";
+      pad_kernel = (do_pad || do_conjugate);
+      kernel_name = (pad_kernel) ? "TransposePadMatrix" : "TransposeMatrix";
     }
   }
   else {
@@ -97,7 +99,8 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
     }
     else {
       use_fast_kernel = false;
-      kernel_name = (do_pad) ? "CopyPadMatrix" : "CopyMatrix";
+      pad_kernel = do_pad;
+      kernel_name = (pad_kernel) ? "CopyPadMatrix" : "CopyMatrix";
     }
   }
 
@@ -123,7 +126,7 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
       kernel.SetArgument(8, static_cast<int>(dest_offset));
       kernel.SetArgument(9, dest());
       kernel.SetArgument(10, GetRealArg(alpha));
-      if (do_pad) {
+      if (pad_kernel) {
         kernel.SetArgument(11, static_cast<int>(do_conjugate));
       }
       else {