diff options
author | CNugteren <web@cedricnugteren.nl> | 2015-06-16 08:42:52 +0200 |
---|---|---|
committer | CNugteren <web@cedricnugteren.nl> | 2015-06-16 08:42:52 +0200 |
commit | 7e176ccac9779bd9929543127108593e0fd3b429 (patch) | |
tree | 8866c97c0b5862459f7af96a51f847ecb6e759b0 /src/kernels | |
parent | d7a0d970e088c85252740c1be591204bd6407cde (diff) |
Added support for conjugate transpose in GEMV
Diffstat (limited to 'src/kernels')
-rw-r--r-- | src/kernels/xgemv.opencl | 27 |
1 files changed, 20 insertions, 7 deletions
diff --git a/src/kernels/xgemv.opencl b/src/kernels/xgemv.opencl index 5ea70e0d..4bb69090 100644 --- a/src/kernels/xgemv.opencl +++ b/src/kernels/xgemv.opencl @@ -58,7 +58,8 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, const int a_rotated, const __global real* restrict agm, const int a_offset, const int a_ld, const __global real* restrict xgm, const int x_offset, const int x_inc, - __global real* ygm, const int y_offset, const int y_inc) { + __global real* ygm, const int y_offset, const int y_inc, + const int do_conjugate) { // Local memory for the vector X __local real xlm[WGS1]; @@ -95,14 +96,18 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, #pragma unroll for (int kl=0; kl<WGS1; ++kl) { const int k = kwg + kl; - MultiplyAdd(acc[w], xlm[kl], agm[gid + a_ld*k + a_offset]); + real value = agm[gid + a_ld*k + a_offset]; + if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } + MultiplyAdd(acc[w], xlm[kl], value); } } else { // Transposed #pragma unroll for (int kl=0; kl<WGS1; ++kl) { const int k = kwg + kl; - MultiplyAdd(acc[w], xlm[kl], agm[k + a_ld*gid + a_offset]); + real value = agm[k + a_ld*gid + a_offset]; + if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } + MultiplyAdd(acc[w], xlm[kl], value); } } } @@ -122,13 +127,17 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, if (a_rotated == 0) { // Not rotated #pragma unroll for (int k=n_floor; k<n; ++k) { - MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], agm[gid + a_ld*k + a_offset]); + real value = agm[gid + a_ld*k + a_offset]; + if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } + MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value); } } else { // Transposed #pragma unroll for (int k=n_floor; k<n; ++k) { - MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], agm[k + a_ld*gid + a_offset]); + real value = agm[k + a_ld*gid + a_offset]; + if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } + MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value); } } @@ -159,12 +168,14 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, // --> 'a_offset' is 0 // --> 'a_ld' is a multiple of VW2 // --> 'a_rotated' is 0 +// --> 'do_conjugate' is 0 __attribute__((reqd_work_group_size(WGS2, 1, 1))) __kernel void XgemvFast(const int m, const int n, const real alpha, const real beta, const int a_rotated, const __global realVF* restrict agm, const int a_offset, const int a_ld, const __global real* restrict xgm, const int x_offset, const int x_inc, - __global real* ygm, const int y_offset, const int y_inc) { + __global real* ygm, const int y_offset, const int y_inc, + const int do_conjugate) { // Local memory for the vector X __local real xlm[WGS2]; @@ -265,12 +276,14 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b // --> 'a_offset' is 0 // --> 'a_ld' is a multiple of VW3 // --> 'a_rotated' is 1 +// --> 'do_conjugate' is 0 __attribute__((reqd_work_group_size(WGS3, 1, 1))) __kernel void XgemvFastRot(const int m, const int n, const real alpha, const real beta, const int a_rotated, const __global realVFR* restrict agm, const int a_offset, const int a_ld, const __global real* restrict xgm, const int x_offset, const int x_inc, - __global real* ygm, const int y_offset, const int y_inc) { + __global real* ygm, const int y_offset, const int y_inc, + const int do_conjugate) { // Local memory for the vector X __local real xlm[WGS3]; |