summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCNugteren <web@cedricnugteren.nl>2015-06-16 07:43:19 +0200
committerCNugteren <web@cedricnugteren.nl>2015-06-16 07:43:19 +0200
commit8f01c644b5c62958c1dcd4fd72b411f3805b81a6 (patch)
treed3e5e937904a5206c503769c38cc11912b12a3ad /src
parent9e2fba9ab9cab1f94dfe143fc6e163f47b6d6f39 (diff)
Added support for complex conjugate transpose
Diffstat (limited to 'src')
-rw-r--r--src/kernels/common.opencl7
-rw-r--r--src/kernels/pad.opencl4
-rw-r--r--src/kernels/padtranspose.opencl6
-rw-r--r--src/routine.cc9
-rw-r--r--src/routines/xgemm.cc12
5 files changed, 29 insertions, 9 deletions
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index 154265e4..818c725f 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -112,6 +112,13 @@ R"(
#define AXPBY(e, a, b, c, d) e = a*b + c*d
#endif
+// The complex conjugate operation for complex transforms
+#if PRECISION == 3232 || PRECISION == 6464
+ #define COMPLEX_CONJUGATE(value) value.x = value.x; value.y = -value.y
+#else
+ #define COMPLEX_CONJUGATE(value) value = value
+#endif
+
// =================================================================================================
// End of the C++11 raw string literal
diff --git a/src/kernels/pad.opencl b/src/kernels/pad.opencl
index ccaeb9d6..45eaef91 100644
--- a/src/kernels/pad.opencl
+++ b/src/kernels/pad.opencl
@@ -47,7 +47,8 @@ __kernel void PadMatrix(const int src_one, const int src_two,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
- __global real* dest) {
+ __global real* dest,
+ const int do_conjugate) {
// Loops over the work per thread in both dimensions
#pragma unroll
@@ -67,6 +68,7 @@ __kernel void PadMatrix(const int src_one, const int src_two,
}
// Stores the value in the destination matrix
+ if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
dest[id_two*dest_ld + id_one + dest_offset] = value;
}
}
diff --git a/src/kernels/padtranspose.opencl b/src/kernels/padtranspose.opencl
index 67cbf341..2f2aabd6 100644
--- a/src/kernels/padtranspose.opencl
+++ b/src/kernels/padtranspose.opencl
@@ -40,7 +40,8 @@ __kernel void PadTransposeMatrix(const int src_one, const int src_two,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
- __global real* dest) {
+ __global real* dest,
+ const int do_conjugate) {
// Local memory to store a tile of the matrix (for coalescing)
__local real tile[PADTRA_WPT*PADTRA_TILE][PADTRA_WPT*PADTRA_TILE + PADTRA_PAD];
@@ -83,12 +84,15 @@ __kernel void PadTransposeMatrix(const int src_one, const int src_two,
// Stores the transposed value in the destination matrix
if ((id_dest_one < dest_one) && (id_dest_two < dest_two)) {
real value = tile[get_local_id(0)*PADTRA_WPT + w_two][get_local_id(1)*PADTRA_WPT + w_one];
+ if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
dest[id_dest_two*dest_ld + id_dest_one + dest_offset] = value;
}
}
}
}
+// =================================================================================================
+
// Same as UnPadCopyMatrix, but now also does the transpose
__attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
__kernel void UnPadTransposeMatrix(const int src_one, const int src_two,
diff --git a/src/routine.cc b/src/routine.cc
index 32face4a..064db754 100644
--- a/src/routine.cc
+++ b/src/routine.cc
@@ -209,11 +209,11 @@ StatusCode Routine::PadCopyTransposeMatrix(const size_t src_one, const size_t sr
const size_t dest_one, const size_t dest_two,
const size_t dest_ld, const size_t dest_offset,
const Buffer &dest,
- const bool do_transpose, const bool pad,
- const Program &program) {
+ const bool do_transpose, const bool do_conjugate,
+ const bool pad, const Program &program) {
// Determines whether or not the fast-version could potentially be used
- auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) &&
+ auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) && (do_conjugate == false) &&
(src_one == dest_one) && (src_two == dest_two) && (src_ld == dest_ld);
// Determines the right kernel
@@ -264,6 +264,9 @@ StatusCode Routine::PadCopyTransposeMatrix(const size_t src_one, const size_t sr
kernel.SetArgument(7, static_cast<int>(dest_ld));
kernel.SetArgument(8, static_cast<int>(dest_offset));
kernel.SetArgument(9, dest());
+ if (pad) {
+ kernel.SetArgument(10, static_cast<int>(do_conjugate));
+ }
}
// Launches the kernel and returns the error code. Uses global and local thread sizes based on
diff --git a/src/routines/xgemm.cc b/src/routines/xgemm.cc
index 16bbc154..db10899c 100644
--- a/src/routines/xgemm.cc
+++ b/src/routines/xgemm.cc
@@ -63,6 +63,10 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
auto b_do_transpose = !b_rotated;
auto c_do_transpose = c_rotated;
+ // In case of complex data-types, the transpose can also become a conjugate transpose
+ auto a_conjugate = (a_transpose == Transpose::kConjugate);
+ auto b_conjugate = (b_transpose == Transpose::kConjugate);
+
// Computes the first and second dimensions of the 3 matrices taking into account whether the
// matrices are rotated or not
auto a_one = (a_rotated) ? k : m;
@@ -104,18 +108,18 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
// them up until they reach a certain multiple of size (kernel parameter dependent).
status = PadCopyTransposeMatrix(a_one, a_two, a_ld, a_offset, a_buffer,
m_ceiled, k_ceiled, m_ceiled, 0, temp_a,
- a_do_transpose, true, program);
+ a_do_transpose, a_conjugate, true, program);
if (ErrorIn(status)) { return status; }
status = PadCopyTransposeMatrix(b_one, b_two, b_ld, b_offset, b_buffer,
n_ceiled, k_ceiled, n_ceiled, 0, temp_b,
- b_do_transpose, true, program);
+ b_do_transpose, b_conjugate, true, program);
if (ErrorIn(status)) { return status; }
// Only necessary for matrix C if it used both as input and output
if (beta != static_cast<T>(0)) {
status = PadCopyTransposeMatrix(c_one, c_two, c_ld, c_offset, c_buffer,
m_ceiled, n_ceiled, m_ceiled, 0, temp_c,
- c_do_transpose, true, program);
+ c_do_transpose, false, true, program);
if (ErrorIn(status)) { return status; }
}
@@ -147,7 +151,7 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
// Runs the post-processing kernel
status = PadCopyTransposeMatrix(m_ceiled, n_ceiled, m_ceiled, 0, temp_c,
c_one, c_two, c_ld, c_offset, c_buffer,
- c_do_transpose, false, program);
+ c_do_transpose, false, false, program);
if (ErrorIn(status)) { return status; }
// Successfully finished the computation