diff options
author | CNugteren <web@cedricnugteren.nl> | 2015-06-16 08:42:52 +0200 |
---|---|---|
committer | CNugteren <web@cedricnugteren.nl> | 2015-06-16 08:42:52 +0200 |
commit | 7e176ccac9779bd9929543127108593e0fd3b429 (patch) | |
tree | 8866c97c0b5862459f7af96a51f847ecb6e759b0 /src | |
parent | d7a0d970e088c85252740c1be591204bd6407cde (diff) |
Added support for conjugate transpose in GEMV
Diffstat (limited to 'src')
-rw-r--r-- | src/kernels/xgemv.opencl | 27 | ||||
-rw-r--r-- | src/routines/xgemv.cc | 10 | ||||
-rw-r--r-- | src/tuning/xgemv.cc | 1 |
3 files changed, 28 insertions, 10 deletions
```diff
diff --git a/src/kernels/xgemv.opencl b/src/kernels/xgemv.opencl
index 5ea70e0d..4bb69090 100644
--- a/src/kernels/xgemv.opencl
+++ b/src/kernels/xgemv.opencl
@@ -58,7 +58,8 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
                     const int a_rotated,
                     const __global real* restrict agm, const int a_offset, const int a_ld,
                     const __global real* restrict xgm, const int x_offset, const int x_inc,
-                    __global real* ygm, const int y_offset, const int y_inc) {
+                    __global real* ygm, const int y_offset, const int y_inc,
+                    const int do_conjugate) {
 
   // Local memory for the vector X
   __local real xlm[WGS1];
@@ -95,14 +96,18 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
           #pragma unroll
           for (int kl=0; kl<WGS1; ++kl) {
             const int k = kwg + kl;
-            MultiplyAdd(acc[w], xlm[kl], agm[gid + a_ld*k + a_offset]);
+            real value = agm[gid + a_ld*k + a_offset];
+            if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
+            MultiplyAdd(acc[w], xlm[kl], value);
           }
         }
         else { // Transposed
           #pragma unroll
           for (int kl=0; kl<WGS1; ++kl) {
             const int k = kwg + kl;
-            MultiplyAdd(acc[w], xlm[kl], agm[k + a_ld*gid + a_offset]);
+            real value = agm[k + a_ld*gid + a_offset];
+            if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
+            MultiplyAdd(acc[w], xlm[kl], value);
           }
         }
       }
@@ -122,13 +127,17 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
       if (a_rotated == 0) { // Not rotated
         #pragma unroll
         for (int k=n_floor; k<n; ++k) {
-          MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], agm[gid + a_ld*k + a_offset]);
+          real value = agm[gid + a_ld*k + a_offset];
+          if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
+          MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value);
         }
       }
       else { // Transposed
         #pragma unroll
         for (int k=n_floor; k<n; ++k) {
-          MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], agm[k + a_ld*gid + a_offset]);
+          real value = agm[k + a_ld*gid + a_offset];
+          if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
+          MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value);
         }
       }
 
@@ -159,12 +168,14 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
 // --> 'a_offset' is 0
 // --> 'a_ld' is a multiple of VW2
 // --> 'a_rotated' is 0
+// --> 'do_conjugate' is 0
 __attribute__((reqd_work_group_size(WGS2, 1, 1)))
 __kernel void XgemvFast(const int m, const int n, const real alpha, const real beta,
                         const int a_rotated,
                         const __global realVF* restrict agm, const int a_offset, const int a_ld,
                         const __global real* restrict xgm, const int x_offset, const int x_inc,
-                        __global real* ygm, const int y_offset, const int y_inc) {
+                        __global real* ygm, const int y_offset, const int y_inc,
+                        const int do_conjugate) {
 
   // Local memory for the vector X
   __local real xlm[WGS2];
@@ -265,12 +276,14 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b
 // --> 'a_offset' is 0
 // --> 'a_ld' is a multiple of VW3
 // --> 'a_rotated' is 1
+// --> 'do_conjugate' is 0
 __attribute__((reqd_work_group_size(WGS3, 1, 1)))
 __kernel void XgemvFastRot(const int m, const int n, const real alpha, const real beta,
                            const int a_rotated,
                            const __global realVFR* restrict agm, const int a_offset, const int a_ld,
                            const __global real* restrict xgm, const int x_offset, const int x_inc,
-                           __global real* ygm, const int y_offset, const int y_inc) {
+                           __global real* ygm, const int y_offset, const int y_inc,
+                           const int do_conjugate) {
 
   // Local memory for the vector X
   __local real xlm[WGS3];
diff --git a/src/routines/xgemv.cc b/src/routines/xgemv.cc
index 9f3908f8..78071c17 100644
--- a/src/routines/xgemv.cc
+++ b/src/routines/xgemv.cc
@@ -54,13 +54,16 @@ StatusCode Xgemv<T>::DoGemv(const Layout layout, const Transpose a_transpose,
   auto a_two = (a_altlayout) ? m : n;
 
   // Swap m and n if the matrix is transposed
-  auto a_transposed = (a_transpose == Transpose::kYes);
+  auto a_transposed = (a_transpose != Transpose::kNo);
   auto m_real = (a_transposed) ? n : m;
   auto n_real = (a_transposed) ? m : n;
 
   // Determines whether the kernel needs to perform rotated access ('^' is the XOR operator)
   auto a_rotated = a_transposed ^ a_altlayout;
 
+  // In case of complex data-types, the transpose can also become a conjugate transpose
+  auto a_conjugate = (a_transpose == Transpose::kConjugate);
+
   // Tests the matrix and the vectors for validity
   auto status = TestMatrixA(a_one, a_two, a_buffer, a_offset, a_ld, sizeof(T));
   if (ErrorIn(status)) { return status; }
@@ -70,11 +73,11 @@ StatusCode Xgemv<T>::DoGemv(const Layout layout, const Transpose a_transpose,
   if (ErrorIn(status)) { return status; }
 
   // Determines whether or not the fast-version can be used
-  bool use_fast_kernel = (a_offset == 0) && (a_rotated == 0) &&
+  bool use_fast_kernel = (a_offset == 0) && (a_rotated == 0) && (a_conjugate == 0) &&
                          IsMultiple(m, db_["WGS2"]*db_["WPT2"]) &&
                          IsMultiple(n, db_["WGS2"]) &&
                          IsMultiple(a_ld, db_["VW2"]);
-  bool use_fast_kernel_rot = (a_offset == 0) && (a_rotated == 1) &&
+  bool use_fast_kernel_rot = (a_offset == 0) && (a_rotated == 1) && (a_conjugate == 0) &&
                              IsMultiple(m, db_["WGS3"]*db_["WPT3"]) &&
                              IsMultiple(n, db_["WGS3"]) &&
                              IsMultiple(a_ld, db_["VW3"]);
@@ -115,6 +118,7 @@ StatusCode Xgemv<T>::DoGemv(const Layout layout, const Transpose a_transpose,
     kernel.SetArgument(11, y_buffer());
     kernel.SetArgument(12, static_cast<int>(y_offset));
     kernel.SetArgument(13, static_cast<int>(y_inc));
+    kernel.SetArgument(14, static_cast<int>(a_conjugate));
 
     // Launches the kernel
     auto global = std::vector<size_t>{global_size};
diff --git a/src/tuning/xgemv.cc b/src/tuning/xgemv.cc
index dccd250c..48df6f25 100644
--- a/src/tuning/xgemv.cc
+++ b/src/tuning/xgemv.cc
@@ -90,6 +90,7 @@ void XgemvTune(const Arguments<T> &args, const size_t variation,
   tuner.AddArgumentOutput(y_vec);
   tuner.AddArgumentScalar(0);
   tuner.AddArgumentScalar(1);
+  tuner.AddArgumentScalar(0); // Conjugate transpose
 }
 
 // =================================================================================================
```