diff options
author | CNugteren <web@cedricnugteren.nl> | 2015-08-03 07:37:14 +0200 |
---|---|---|
committer | CNugteren <web@cedricnugteren.nl> | 2015-08-03 07:37:14 +0200 |
commit | d1a7cf18ecfee1879d00e3a19ce129ee058dd84f (patch) | |
tree | e54462deb1fafcdb65f5d573ad1a788d87f08343 /src | |
parent | fc7cd434e15b51ff8d39a0fcc8acba4b861ffd18 (diff) |
Abstracted loading of matrix A for GEMV kernel
Diffstat (limited to 'src')
-rw-r--r-- | src/kernels/xgemv.opencl | 156 |
1 files changed, 94 insertions, 62 deletions
diff --git a/src/kernels/xgemv.opencl b/src/kernels/xgemv.opencl index 65061717..5bbf69b9 100644 --- a/src/kernels/xgemv.opencl +++ b/src/kernels/xgemv.opencl @@ -52,6 +52,63 @@ R"( // ================================================================================================= +// Data-widths for the 'fast' kernel +#if VW2 == 1 + typedef real realVF; +#elif VW2 == 2 + typedef real2 realVF; +#elif VW2 == 4 + typedef real4 realVF; +#elif VW2 == 8 + typedef real8 realVF; +#elif VW2 == 16 + typedef real16 realVF; +#endif + +// Data-widths for the 'fast' kernel with rotated matrix +#if VW3 == 1 + typedef real realVFR; +#elif VW3 == 2 + typedef real2 realVFR; +#elif VW3 == 4 + typedef real4 realVFR; +#elif VW3 == 8 + typedef real8 realVFR; +#elif VW3 == 16 + typedef real16 realVFR; +#endif + +// ================================================================================================= +// Defines how to load the input matrix in case of a symmetric matrix +#if defined(ROUTINE_SYMV) + +// ================================================================================================= +// Defines how to load the input matrix in case of a Hermitian matrix +#elif defined(ROUTINE_HEMV) + +// ================================================================================================= +// Defines how to load the input matrix in the regular case +#else + +// Loads a scalar input value +inline real LoadMatrixA(const __global real* restrict agm, const int x, const int y, + const int a_ld, const int a_offset) { + return agm[x + a_ld*y + a_offset]; +} +// Loads a vector input value (1/2) +inline realVF LoadMatrixAVF(const __global realVF* restrict agm, const int x, const int y, + const int a_ld) { + return agm[x + a_ld*y]; +} +// Loads a vector input value (2/2): as before, but different data-type +inline realVFR LoadMatrixAVFR(const __global realVFR* restrict agm, const int x, const int y, + const int a_ld) { + return agm[x + a_ld*y]; +} + +#endif +// 
================================================================================================= + // Full version of the kernel __attribute__((reqd_work_group_size(WGS1, 1, 1))) __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, @@ -96,7 +153,7 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, #pragma unroll for (int kl=0; kl<WGS1; ++kl) { const int k = kwg + kl; - real value = agm[gid + a_ld*k + a_offset]; + real value = LoadMatrixA(agm, gid, k, a_ld, a_offset); if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } MultiplyAdd(acc[w], xlm[kl], value); } @@ -105,7 +162,7 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, #pragma unroll for (int kl=0; kl<WGS1; ++kl) { const int k = kwg + kl; - real value = agm[k + a_ld*gid + a_offset]; + real value = LoadMatrixA(agm, k, gid, a_ld, a_offset); if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } MultiplyAdd(acc[w], xlm[kl], value); } @@ -127,7 +184,7 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, if (a_rotated == 0) { // Not rotated #pragma unroll for (int k=n_floor; k<n; ++k) { - real value = agm[gid + a_ld*k + a_offset]; + real value = LoadMatrixA(agm, gid, k, a_ld, a_offset); if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value); } @@ -135,7 +192,7 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, else { // Transposed #pragma unroll for (int k=n_floor; k<n; ++k) { - real value = agm[k + a_ld*gid + a_offset]; + real value = LoadMatrixA(agm, k, gid, a_ld, a_offset); if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value); } @@ -150,19 +207,6 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, // ================================================================================================= -// Data-widths for the 'fast' 
kernel -#if VW2 == 1 - typedef real realVF; -#elif VW2 == 2 - typedef real2 realVF; -#elif VW2 == 4 - typedef real4 realVF; -#elif VW2 == 8 - typedef real8 realVF; -#elif VW2 == 16 - typedef real16 realVF; -#endif - // Faster version of the kernel, assuming that: // --> 'm' and 'n' are multiples of WGS2 // --> 'a_offset' is 0 @@ -203,42 +247,43 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b #pragma unroll for (int w=0; w<WPT2/VW2; ++w) { const int gid = (WPT2/VW2)*get_global_id(0) + w; + realVF avec = LoadMatrixAVF(agm, gid, k, a_ld/VW2); #if VW2 == 1 - MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k]); + MultiplyAdd(acc[VW2*w+0], xlm[kl], avec); #elif VW2 == 2 - MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k].x); - MultiplyAdd(acc[VW2*w+1], xlm[kl], agm[gid + (a_ld/VW2)*k].y); + MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.x); + MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.y); #elif VW2 == 4 - MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k].x); - MultiplyAdd(acc[VW2*w+1], xlm[kl], agm[gid + (a_ld/VW2)*k].y); - MultiplyAdd(acc[VW2*w+2], xlm[kl], agm[gid + (a_ld/VW2)*k].z); - MultiplyAdd(acc[VW2*w+3], xlm[kl], agm[gid + (a_ld/VW2)*k].w); + MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.x); + MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.y); + MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.z); + MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.w); #elif VW2 == 8 - MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k].s0); - MultiplyAdd(acc[VW2*w+1], xlm[kl], agm[gid + (a_ld/VW2)*k].s1); - MultiplyAdd(acc[VW2*w+2], xlm[kl], agm[gid + (a_ld/VW2)*k].s2); - MultiplyAdd(acc[VW2*w+3], xlm[kl], agm[gid + (a_ld/VW2)*k].s3); - MultiplyAdd(acc[VW2*w+4], xlm[kl], agm[gid + (a_ld/VW2)*k].s4); - MultiplyAdd(acc[VW2*w+5], xlm[kl], agm[gid + (a_ld/VW2)*k].s5); - MultiplyAdd(acc[VW2*w+6], xlm[kl], agm[gid + (a_ld/VW2)*k].s6); - MultiplyAdd(acc[VW2*w+7], xlm[kl], agm[gid + (a_ld/VW2)*k].s7); + MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.s0); + 
MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.s1); + MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.s2); + MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.s3); + MultiplyAdd(acc[VW2*w+4], xlm[kl], avec.s4); + MultiplyAdd(acc[VW2*w+5], xlm[kl], avec.s5); + MultiplyAdd(acc[VW2*w+6], xlm[kl], avec.s6); + MultiplyAdd(acc[VW2*w+7], xlm[kl], avec.s7); #elif VW2 == 16 - MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k].s0); - MultiplyAdd(acc[VW2*w+1], xlm[kl], agm[gid + (a_ld/VW2)*k].s1); - MultiplyAdd(acc[VW2*w+2], xlm[kl], agm[gid + (a_ld/VW2)*k].s2); - MultiplyAdd(acc[VW2*w+3], xlm[kl], agm[gid + (a_ld/VW2)*k].s3); - MultiplyAdd(acc[VW2*w+4], xlm[kl], agm[gid + (a_ld/VW2)*k].s4); - MultiplyAdd(acc[VW2*w+5], xlm[kl], agm[gid + (a_ld/VW2)*k].s5); - MultiplyAdd(acc[VW2*w+6], xlm[kl], agm[gid + (a_ld/VW2)*k].s6); - MultiplyAdd(acc[VW2*w+7], xlm[kl], agm[gid + (a_ld/VW2)*k].s7); - MultiplyAdd(acc[VW2*w+8], xlm[kl], agm[gid + (a_ld/VW2)*k].s8); - MultiplyAdd(acc[VW2*w+9], xlm[kl], agm[gid + (a_ld/VW2)*k].s9); - MultiplyAdd(acc[VW2*w+10], xlm[kl], agm[gid + (a_ld/VW2)*k].sA); - MultiplyAdd(acc[VW2*w+11], xlm[kl], agm[gid + (a_ld/VW2)*k].sB); - MultiplyAdd(acc[VW2*w+12], xlm[kl], agm[gid + (a_ld/VW2)*k].sC); - MultiplyAdd(acc[VW2*w+13], xlm[kl], agm[gid + (a_ld/VW2)*k].sD); - MultiplyAdd(acc[VW2*w+14], xlm[kl], agm[gid + (a_ld/VW2)*k].sE); - MultiplyAdd(acc[VW2*w+15], xlm[kl], agm[gid + (a_ld/VW2)*k].sF); + MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.s0); + MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.s1); + MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.s2); + MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.s3); + MultiplyAdd(acc[VW2*w+4], xlm[kl], avec.s4); + MultiplyAdd(acc[VW2*w+5], xlm[kl], avec.s5); + MultiplyAdd(acc[VW2*w+6], xlm[kl], avec.s6); + MultiplyAdd(acc[VW2*w+7], xlm[kl], avec.s7); + MultiplyAdd(acc[VW2*w+8], xlm[kl], avec.s8); + MultiplyAdd(acc[VW2*w+9], xlm[kl], avec.s9); + MultiplyAdd(acc[VW2*w+10], xlm[kl], avec.sA); + MultiplyAdd(acc[VW2*w+11], xlm[kl], avec.sB); + MultiplyAdd(acc[VW2*w+12], 
xlm[kl], avec.sC); + MultiplyAdd(acc[VW2*w+13], xlm[kl], avec.sD); + MultiplyAdd(acc[VW2*w+14], xlm[kl], avec.sE); + MultiplyAdd(acc[VW2*w+15], xlm[kl], avec.sF); #endif } } @@ -258,19 +303,6 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b // ================================================================================================= -// Data-widths for the 'fast' kernel with rotated matrix -#if VW3 == 1 - typedef real realVFR; -#elif VW3 == 2 - typedef real2 realVFR; -#elif VW3 == 4 - typedef real4 realVFR; -#elif VW3 == 8 - typedef real8 realVFR; -#elif VW3 == 16 - typedef real16 realVFR; -#endif - // Faster version of the kernel, assuming that: // --> 'm' and 'n' are multiples of WGS3 // --> 'a_offset' is 0 @@ -311,7 +343,7 @@ __kernel void XgemvFastRot(const int m, const int n, const real alpha, const rea #pragma unroll for (int w=0; w<WPT3; ++w) { const int gid = WPT3*get_global_id(0) + w; - realVFR avec = agm[k + (a_ld/VW3)*gid]; + realVFR avec = LoadMatrixAVFR(agm, k, gid, a_ld/VW3); #if VW3 == 1 MultiplyAdd(acc[w], xlm[VW3*kl+0], avec); #elif VW3 == 2 |