From 8f01c644b5c62958c1dcd4fd72b411f3805b81a6 Mon Sep 17 00:00:00 2001 From: CNugteren Date: Tue, 16 Jun 2015 07:43:19 +0200 Subject: Added support for complex conjugate transpose --- src/kernels/common.opencl | 7 +++++++ src/kernels/pad.opencl | 4 +++- src/kernels/padtranspose.opencl | 6 +++++- src/routine.cc | 9 ++++++--- src/routines/xgemm.cc | 12 ++++++++---- 5 files changed, 29 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl index 154265e4..818c725f 100644 --- a/src/kernels/common.opencl +++ b/src/kernels/common.opencl @@ -112,6 +112,13 @@ R"( #define AXPBY(e, a, b, c, d) e = a*b + c*d #endif +// The complex conjugate operation for complex transforms +#if PRECISION == 3232 || PRECISION == 6464 + #define COMPLEX_CONJUGATE(value) value.x = value.x; value.y = -value.y +#else + #define COMPLEX_CONJUGATE(value) value = value +#endif + // ================================================================================================= // End of the C++11 raw string literal diff --git a/src/kernels/pad.opencl b/src/kernels/pad.opencl index ccaeb9d6..45eaef91 100644 --- a/src/kernels/pad.opencl +++ b/src/kernels/pad.opencl @@ -47,7 +47,8 @@ __kernel void PadMatrix(const int src_one, const int src_two, __global const real* restrict src, const int dest_one, const int dest_two, const int dest_ld, const int dest_offset, - __global real* dest) { + __global real* dest, + const int do_conjugate) { // Loops over the work per thread in both dimensions #pragma unroll @@ -67,6 +68,7 @@ __kernel void PadMatrix(const int src_one, const int src_two, } // Stores the value in the destination matrix + if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } dest[id_two*dest_ld + id_one + dest_offset] = value; } } diff --git a/src/kernels/padtranspose.opencl b/src/kernels/padtranspose.opencl index 67cbf341..2f2aabd6 100644 --- a/src/kernels/padtranspose.opencl +++ b/src/kernels/padtranspose.opencl @@ -40,7 +40,8 @@ 
__kernel void PadTransposeMatrix(const int src_one, const int src_two, __global const real* restrict src, const int dest_one, const int dest_two, const int dest_ld, const int dest_offset, - __global real* dest) { + __global real* dest, + const int do_conjugate) { // Local memory to store a tile of the matrix (for coalescing) __local real tile[PADTRA_WPT*PADTRA_TILE][PADTRA_WPT*PADTRA_TILE + PADTRA_PAD]; @@ -83,12 +84,15 @@ __kernel void PadTransposeMatrix(const int src_one, const int src_two, // Stores the transposed value in the destination matrix if ((id_dest_one < dest_one) && (id_dest_two < dest_two)) { real value = tile[get_local_id(0)*PADTRA_WPT + w_two][get_local_id(1)*PADTRA_WPT + w_one]; + if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } dest[id_dest_two*dest_ld + id_dest_one + dest_offset] = value; } } } } +// ================================================================================================= + // Same as UnPadCopyMatrix, but now also does the transpose __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1))) __kernel void UnPadTransposeMatrix(const int src_one, const int src_two, diff --git a/src/routine.cc b/src/routine.cc index 32face4a..064db754 100644 --- a/src/routine.cc +++ b/src/routine.cc @@ -209,11 +209,11 @@ StatusCode Routine::PadCopyTransposeMatrix(const size_t src_one, const size_t sr const size_t dest_one, const size_t dest_two, const size_t dest_ld, const size_t dest_offset, const Buffer &dest, - const bool do_transpose, const bool pad, - const Program &program) { + const bool do_transpose, const bool do_conjugate, + const bool pad, const Program &program) { // Determines whether or not the fast-version could potentially be used - auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) && + auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) && (do_conjugate == false) && (src_one == dest_one) && (src_two == dest_two) && (src_ld == dest_ld); // Determines the right kernel @@ -264,6 +264,9 @@ 
StatusCode Routine::PadCopyTransposeMatrix(const size_t src_one, const size_t sr kernel.SetArgument(7, static_cast<int>(dest_ld)); kernel.SetArgument(8, static_cast<int>(dest_offset)); kernel.SetArgument(9, dest()); + if (pad) { + kernel.SetArgument(10, static_cast<int>(do_conjugate)); + } } // Launches the kernel and returns the error code. Uses global and local thread sizes based on diff --git a/src/routines/xgemm.cc b/src/routines/xgemm.cc index 16bbc154..db10899c 100644 --- a/src/routines/xgemm.cc +++ b/src/routines/xgemm.cc @@ -63,6 +63,10 @@ StatusCode Xgemm::DoGemm(const Layout layout, auto b_do_transpose = !b_rotated; auto c_do_transpose = c_rotated; + // In case of complex data-types, the transpose can also become a conjugate transpose + auto a_conjugate = (a_transpose == Transpose::kConjugate); + auto b_conjugate = (b_transpose == Transpose::kConjugate); + // Computes the first and second dimensions of the 3 matrices taking into account whether the // matrices are rotated or not auto a_one = (a_rotated) ? k : m; @@ -104,18 +108,18 @@ StatusCode Xgemm::DoGemm(const Layout layout, // them up until they reach a certain multiple of size (kernel parameter dependent). 
status = PadCopyTransposeMatrix(a_one, a_two, a_ld, a_offset, a_buffer, m_ceiled, k_ceiled, m_ceiled, 0, temp_a, - a_do_transpose, true, program); + a_do_transpose, a_conjugate, true, program); if (ErrorIn(status)) { return status; } status = PadCopyTransposeMatrix(b_one, b_two, b_ld, b_offset, b_buffer, n_ceiled, k_ceiled, n_ceiled, 0, temp_b, - b_do_transpose, true, program); + b_do_transpose, b_conjugate, true, program); if (ErrorIn(status)) { return status; } // Only necessary for matrix C if it used both as input and output if (beta != static_cast(0)) { status = PadCopyTransposeMatrix(c_one, c_two, c_ld, c_offset, c_buffer, m_ceiled, n_ceiled, m_ceiled, 0, temp_c, - c_do_transpose, true, program); + c_do_transpose, false, true, program); if (ErrorIn(status)) { return status; } } @@ -147,7 +151,7 @@ StatusCode Xgemm::DoGemm(const Layout layout, // Runs the post-processing kernel status = PadCopyTransposeMatrix(m_ceiled, n_ceiled, m_ceiled, 0, temp_c, c_one, c_two, c_ld, c_offset, c_buffer, - c_do_transpose, false, program); + c_do_transpose, false, false, program); if (ErrorIn(status)) { return status; } // Successfully finished the computation -- cgit v1.2.3