diff options
Diffstat (limited to 'src/kernels/level2')
-rw-r--r-- | src/kernels/level2/xgemv_fast.opencl | 84 |
1 files changed, 43 insertions, 41 deletions
diff --git a/src/kernels/level2/xgemv_fast.opencl b/src/kernels/level2/xgemv_fast.opencl index 1d34de96..359c3770 100644 --- a/src/kernels/level2/xgemv_fast.opencl +++ b/src/kernels/level2/xgemv_fast.opencl @@ -204,10 +204,10 @@ __kernel void XgemvFastRot(const int m, const int n, const real beta = GetRealArg(arg_beta); // Local memory to store a tile of the matrix (for coalescing) - __local real tile[WGS3 * WPT3]; + __local real tile[WPT3][WGS3]; const int lid = get_local_id(0); - const int lid_mod = lid % WPT3; - const int lid_div = lid / WPT3; + const int lid_mod = lid % (WPT3/VW3); + const int lid_div = lid / (WPT3/VW3); // Local memory for the vector X __local real xlm[WPT3]; @@ -225,45 +225,45 @@ __kernel void XgemvFastRot(const int m, const int n, // Loads the matrix A into local memory #pragma unroll for (int kl=0; kl<WPT3/VW3; ++kl) { - const int x = (kwg/VW3) + kl; - const int y = get_group_id(0) * WGS3 + lid; + const int x = (kwg/VW3) + lid_mod; + const int y = get_group_id(0) * WGS3 + lid_div * (WPT3/VW3) + kl; realVFR avec = agm[(a_ld/VW3) * y + x]; #if VW3 == 1 - tile[(kl*VW3 + 0) * WGS3 + lid] = avec; + tile[kl*VW3 + 0][lid] = avec; #elif VW3 == 2 - tile[(kl*VW3 + 0) * WGS3 + lid] = avec.x; - tile[(kl*VW3 + 1) * WGS3 + lid] = avec.y; + tile[kl*VW3 + 0][lid] = avec.x; + tile[kl*VW3 + 1][lid] = avec.y; #elif VW3 == 4 - tile[(kl*VW3 + 0) * WGS3 + lid] = avec.x; - tile[(kl*VW3 + 1) * WGS3 + lid] = avec.y; - tile[(kl*VW3 + 2) * WGS3 + lid] = avec.z; - tile[(kl*VW3 + 3) * WGS3 + lid] = avec.w; + tile[kl*VW3 + 0][lid] = avec.x; + tile[kl*VW3 + 1][lid] = avec.y; + tile[kl*VW3 + 2][lid] = avec.z; + tile[kl*VW3 + 3][lid] = avec.w; #elif VW3 == 8 - tile[(kl*VW3 + 0) * WGS3 + lid] = avec.s0; - tile[(kl*VW3 + 1) * WGS3 + lid] = avec.s1; - tile[(kl*VW3 + 2) * WGS3 + lid] = avec.s2; - tile[(kl*VW3 + 3) * WGS3 + lid] = avec.s3; - tile[(kl*VW3 + 4) * WGS3 + lid] = avec.s4; - tile[(kl*VW3 + 5) * WGS3 + lid] = avec.s5; - tile[(kl*VW3 + 6) * WGS3 + lid] = avec.s6; - tile[(kl*VW3 + 7) * WGS3 + lid] = avec.s7; + tile[kl*VW3 + 0][lid] = avec.s0; + tile[kl*VW3 + 1][lid] = avec.s1; + tile[kl*VW3 + 2][lid] = avec.s2; + tile[kl*VW3 + 3][lid] = avec.s3; + tile[kl*VW3 + 4][lid] = avec.s4; + tile[kl*VW3 + 5][lid] = avec.s5; + tile[kl*VW3 + 6][lid] = avec.s6; + tile[kl*VW3 + 7][lid] = avec.s7; #elif VW3 == 16 - tile[(kl*VW3 + 0) * WGS3 + lid] = avec.s0; - tile[(kl*VW3 + 1) * WGS3 + lid] = avec.s1; - tile[(kl*VW3 + 2) * WGS3 + lid] = avec.s2; - tile[(kl*VW3 + 3) * WGS3 + lid] = avec.s3; - tile[(kl*VW3 + 4) * WGS3 + lid] = avec.s4; - tile[(kl*VW3 + 5) * WGS3 + lid] = avec.s5; - tile[(kl*VW3 + 6) * WGS3 + lid] = avec.s6; - tile[(kl*VW3 + 7) * WGS3 + lid] = avec.s7; - tile[(kl*VW3 + 8) * WGS3 + lid] = avec.s8; - tile[(kl*VW3 + 9) * WGS3 + lid] = avec.s9; - tile[(kl*VW3 + 10) * WGS3 + lid] = avec.sA; - tile[(kl*VW3 + 11) * WGS3 + lid] = avec.sB; - tile[(kl*VW3 + 12) * WGS3 + lid] = avec.sC; - tile[(kl*VW3 + 13) * WGS3 + lid] = avec.sD; - tile[(kl*VW3 + 14) * WGS3 + lid] = avec.sE; - tile[(kl*VW3 + 15) * WGS3 + lid] = avec.sF; + tile[kl*VW3 + 0][lid] = avec.s0; + tile[kl*VW3 + 1][lid] = avec.s1; + tile[kl*VW3 + 2][lid] = avec.s2; + tile[kl*VW3 + 3][lid] = avec.s3; + tile[kl*VW3 + 4][lid] = avec.s4; + tile[kl*VW3 + 5][lid] = avec.s5; + tile[kl*VW3 + 6][lid] = avec.s6; + tile[kl*VW3 + 7][lid] = avec.s7; + tile[kl*VW3 + 8][lid] = avec.s8; + tile[kl*VW3 + 9][lid] = avec.s9; + tile[kl*VW3 + 10][lid] = avec.sA; + tile[kl*VW3 + 11][lid] = avec.sB; + tile[kl*VW3 + 12][lid] = avec.sC; + tile[kl*VW3 + 13][lid] = avec.sD; + tile[kl*VW3 + 14][lid] = avec.sE; + tile[kl*VW3 + 15][lid] = avec.sF; #endif } @@ -272,11 +272,13 @@ __kernel void XgemvFastRot(const int m, const int n, // The multiply-add function (rotated) #pragma unroll - for (int kl=0; kl<WPT3; ++kl) { - const int k = kl * (WGS3/WPT3) + lid_div; - real aval = tile[k * WPT3 + lid_mod]; - real xval = xlm[kl]; - MultiplyAdd(acc, xval, aval); + for (int kl=0; kl<WPT3/VW3; ++kl) { + #pragma unroll + for (int v=0; v<VW3; ++v) { + real aval = tile[lid_mod*VW3 + v][lid_div * (WPT3/VW3) + kl]; + real xval = xlm[kl*VW3 + v]; + MultiplyAdd(acc, xval, aval); + } } // Synchronizes all threads in a workgroup |