summaryrefslogtreecommitdiff
path: root/src/kernels/level2/xgemv_fast.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/level2/xgemv_fast.opencl')
-rw-r--r--src/kernels/level2/xgemv_fast.opencl84
1 files changed, 43 insertions, 41 deletions
diff --git a/src/kernels/level2/xgemv_fast.opencl b/src/kernels/level2/xgemv_fast.opencl
index 1d34de96..359c3770 100644
--- a/src/kernels/level2/xgemv_fast.opencl
+++ b/src/kernels/level2/xgemv_fast.opencl
@@ -204,10 +204,10 @@ __kernel void XgemvFastRot(const int m, const int n,
const real beta = GetRealArg(arg_beta);
// Local memory to store a tile of the matrix (for coalescing)
- __local real tile[WGS3 * WPT3];
+ __local real tile[WPT3][WGS3];
const int lid = get_local_id(0);
- const int lid_mod = lid % WPT3;
- const int lid_div = lid / WPT3;
+ const int lid_mod = lid % (WPT3/VW3);
+ const int lid_div = lid / (WPT3/VW3);
// Local memory for the vector X
__local real xlm[WPT3];
@@ -225,45 +225,45 @@ __kernel void XgemvFastRot(const int m, const int n,
// Loads the matrix A into local memory
#pragma unroll
for (int kl=0; kl<WPT3/VW3; ++kl) {
- const int x = (kwg/VW3) + kl;
- const int y = get_group_id(0) * WGS3 + lid;
+ const int x = (kwg/VW3) + lid_mod;
+ const int y = get_group_id(0) * WGS3 + lid_div * (WPT3/VW3) + kl;
realVFR avec = agm[(a_ld/VW3) * y + x];
#if VW3 == 1
- tile[(kl*VW3 + 0) * WGS3 + lid] = avec;
+ tile[kl*VW3 + 0][lid] = avec;
#elif VW3 == 2
- tile[(kl*VW3 + 0) * WGS3 + lid] = avec.x;
- tile[(kl*VW3 + 1) * WGS3 + lid] = avec.y;
+ tile[kl*VW3 + 0][lid] = avec.x;
+ tile[kl*VW3 + 1][lid] = avec.y;
#elif VW3 == 4
- tile[(kl*VW3 + 0) * WGS3 + lid] = avec.x;
- tile[(kl*VW3 + 1) * WGS3 + lid] = avec.y;
- tile[(kl*VW3 + 2) * WGS3 + lid] = avec.z;
- tile[(kl*VW3 + 3) * WGS3 + lid] = avec.w;
+ tile[kl*VW3 + 0][lid] = avec.x;
+ tile[kl*VW3 + 1][lid] = avec.y;
+ tile[kl*VW3 + 2][lid] = avec.z;
+ tile[kl*VW3 + 3][lid] = avec.w;
#elif VW3 == 8
- tile[(kl*VW3 + 0) * WGS3 + lid] = avec.s0;
- tile[(kl*VW3 + 1) * WGS3 + lid] = avec.s1;
- tile[(kl*VW3 + 2) * WGS3 + lid] = avec.s2;
- tile[(kl*VW3 + 3) * WGS3 + lid] = avec.s3;
- tile[(kl*VW3 + 4) * WGS3 + lid] = avec.s4;
- tile[(kl*VW3 + 5) * WGS3 + lid] = avec.s5;
- tile[(kl*VW3 + 6) * WGS3 + lid] = avec.s6;
- tile[(kl*VW3 + 7) * WGS3 + lid] = avec.s7;
+ tile[kl*VW3 + 0][lid] = avec.s0;
+ tile[kl*VW3 + 1][lid] = avec.s1;
+ tile[kl*VW3 + 2][lid] = avec.s2;
+ tile[kl*VW3 + 3][lid] = avec.s3;
+ tile[kl*VW3 + 4][lid] = avec.s4;
+ tile[kl*VW3 + 5][lid] = avec.s5;
+ tile[kl*VW3 + 6][lid] = avec.s6;
+ tile[kl*VW3 + 7][lid] = avec.s7;
#elif VW3 == 16
- tile[(kl*VW3 + 0) * WGS3 + lid] = avec.s0;
- tile[(kl*VW3 + 1) * WGS3 + lid] = avec.s1;
- tile[(kl*VW3 + 2) * WGS3 + lid] = avec.s2;
- tile[(kl*VW3 + 3) * WGS3 + lid] = avec.s3;
- tile[(kl*VW3 + 4) * WGS3 + lid] = avec.s4;
- tile[(kl*VW3 + 5) * WGS3 + lid] = avec.s5;
- tile[(kl*VW3 + 6) * WGS3 + lid] = avec.s6;
- tile[(kl*VW3 + 7) * WGS3 + lid] = avec.s7;
- tile[(kl*VW3 + 8) * WGS3 + lid] = avec.s8;
- tile[(kl*VW3 + 9) * WGS3 + lid] = avec.s9;
- tile[(kl*VW3 + 10) * WGS3 + lid] = avec.sA;
- tile[(kl*VW3 + 11) * WGS3 + lid] = avec.sB;
- tile[(kl*VW3 + 12) * WGS3 + lid] = avec.sC;
- tile[(kl*VW3 + 13) * WGS3 + lid] = avec.sD;
- tile[(kl*VW3 + 14) * WGS3 + lid] = avec.sE;
- tile[(kl*VW3 + 15) * WGS3 + lid] = avec.sF;
+ tile[kl*VW3 + 0][lid] = avec.s0;
+ tile[kl*VW3 + 1][lid] = avec.s1;
+ tile[kl*VW3 + 2][lid] = avec.s2;
+ tile[kl*VW3 + 3][lid] = avec.s3;
+ tile[kl*VW3 + 4][lid] = avec.s4;
+ tile[kl*VW3 + 5][lid] = avec.s5;
+ tile[kl*VW3 + 6][lid] = avec.s6;
+ tile[kl*VW3 + 7][lid] = avec.s7;
+ tile[kl*VW3 + 8][lid] = avec.s8;
+ tile[kl*VW3 + 9][lid] = avec.s9;
+ tile[kl*VW3 + 10][lid] = avec.sA;
+ tile[kl*VW3 + 11][lid] = avec.sB;
+ tile[kl*VW3 + 12][lid] = avec.sC;
+ tile[kl*VW3 + 13][lid] = avec.sD;
+ tile[kl*VW3 + 14][lid] = avec.sE;
+ tile[kl*VW3 + 15][lid] = avec.sF;
#endif
}
@@ -272,11 +272,13 @@ __kernel void XgemvFastRot(const int m, const int n,
// The multiply-add function (rotated)
#pragma unroll
- for (int kl=0; kl<WPT3; ++kl) {
- const int k = kl * (WGS3/WPT3) + lid_div;
- real aval = tile[k * WPT3 + lid_mod];
- real xval = xlm[kl];
- MultiplyAdd(acc, xval, aval);
+ for (int kl=0; kl<WPT3/VW3; ++kl) {
+ #pragma unroll
+ for (int v=0; v<VW3; ++v) {
+ real aval = tile[lid_mod*VW3 + v][lid_div * (WPT3/VW3) + kl];
+ real xval = xlm[kl*VW3 + v];
+ MultiplyAdd(acc, xval, aval);
+ }
}
// Synchronizes all threads in a workgroup