summaryrefslogtreecommitdiff
path: root/src/kernels/level2/xgemv_fast.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/level2/xgemv_fast.opencl')
-rw-r--r--src/kernels/level2/xgemv_fast.opencl137
1 files changed, 68 insertions, 69 deletions
diff --git a/src/kernels/level2/xgemv_fast.opencl b/src/kernels/level2/xgemv_fast.opencl
index 1127a0b6..1d34de96 100644
--- a/src/kernels/level2/xgemv_fast.opencl
+++ b/src/kernels/level2/xgemv_fast.opencl
@@ -38,7 +38,7 @@ R"(
#define WGS3 64 // The local work-group size
#endif
#ifndef WPT3
- #define WPT3 1 // The amount of work-per-thread
+ #define WPT3 1 // The tile-size
#endif
#ifndef VW3
#define VW3 1 // Vector width of matrix A loads
@@ -74,18 +74,12 @@ R"(
// =================================================================================================
-// Loads a vector input value (1/2)
+// Loads a vector input value
inline realVF LoadMatrixAVF(const __global realVF* restrict agm, const int x, const int y,
const int a_ld) {
return agm[a_ld*y + x];
}
-// Loads a vector input value (2/2): as before, but different data-type
-inline realVFR LoadMatrixAVFR(const __global realVFR* restrict agm, const int x, const int y,
- const int a_ld) {
- return agm[a_ld*y + x];
-}
-
// =================================================================================================
// Faster version of the kernel, assuming that:
@@ -110,7 +104,7 @@ __kernel void XgemvFast(const int m, const int n,
// Local memory for the vector X
__local real xlm[WGS2];
- // Initializes the accumulation register
+ // Initializes the accumulation registers
real acc[WPT2];
#pragma unroll
for (int w=0; w<WPT2; ++w) {
@@ -134,7 +128,7 @@ __kernel void XgemvFast(const int m, const int n,
#pragma unroll
for (int w=0; w<WPT2/VW2; ++w) {
const int gid = (WPT2/VW2)*get_global_id(0) + w;
- realVF avec = LoadMatrixAVF(agm, gid, k, a_ld/VW2);
+ realVF avec = agm[(a_ld/VW2)*k + gid];
#if VW2 == 1
MultiplyAdd(acc[VW2*w+0], xlm[kl], avec);
#elif VW2 == 2
@@ -209,72 +203,80 @@ __kernel void XgemvFastRot(const int m, const int n,
const real alpha = GetRealArg(arg_alpha);
const real beta = GetRealArg(arg_beta);
+ // Local memory to store a tile of the matrix (for coalescing)
+ __local real tile[WGS3 * WPT3];
+ const int lid = get_local_id(0);
+ const int lid_mod = lid % WPT3;
+ const int lid_div = lid / WPT3;
+
// Local memory for the vector X
- __local real xlm[WGS3];
+ __local real xlm[WPT3];
// Initializes the accumulation register
- real acc[WPT3];
- #pragma unroll
- for (int w=0; w<WPT3; ++w) {
- SetToZero(acc[w]);
- }
+ real acc;
+ SetToZero(acc);
// Loops over work-group sized portions of the work
- for (int kwg=0; kwg<n; kwg+=WGS3) {
+ for (int kwg=0; kwg<n; kwg+=WPT3) {
// Loads the vector X into local memory
- const int lid = get_local_id(0);
- xlm[lid] = xgm[(kwg + lid)*x_inc + x_offset];
+ xlm[lid] = xgm[(kwg + lid) * x_inc + x_offset];
+
+ // Loads the matrix A into local memory
+ #pragma unroll
+ for (int kl=0; kl<WPT3/VW3; ++kl) {
+ const int x = (kwg/VW3) + kl;
+ const int y = get_group_id(0) * WGS3 + lid;
+ realVFR avec = agm[(a_ld/VW3) * y + x];
+ #if VW3 == 1
+ tile[(kl*VW3 + 0) * WGS3 + lid] = avec;
+ #elif VW3 == 2
+ tile[(kl*VW3 + 0) * WGS3 + lid] = avec.x;
+ tile[(kl*VW3 + 1) * WGS3 + lid] = avec.y;
+ #elif VW3 == 4
+ tile[(kl*VW3 + 0) * WGS3 + lid] = avec.x;
+ tile[(kl*VW3 + 1) * WGS3 + lid] = avec.y;
+ tile[(kl*VW3 + 2) * WGS3 + lid] = avec.z;
+ tile[(kl*VW3 + 3) * WGS3 + lid] = avec.w;
+ #elif VW3 == 8
+ tile[(kl*VW3 + 0) * WGS3 + lid] = avec.s0;
+ tile[(kl*VW3 + 1) * WGS3 + lid] = avec.s1;
+ tile[(kl*VW3 + 2) * WGS3 + lid] = avec.s2;
+ tile[(kl*VW3 + 3) * WGS3 + lid] = avec.s3;
+ tile[(kl*VW3 + 4) * WGS3 + lid] = avec.s4;
+ tile[(kl*VW3 + 5) * WGS3 + lid] = avec.s5;
+ tile[(kl*VW3 + 6) * WGS3 + lid] = avec.s6;
+ tile[(kl*VW3 + 7) * WGS3 + lid] = avec.s7;
+ #elif VW3 == 16
+ tile[(kl*VW3 + 0) * WGS3 + lid] = avec.s0;
+ tile[(kl*VW3 + 1) * WGS3 + lid] = avec.s1;
+ tile[(kl*VW3 + 2) * WGS3 + lid] = avec.s2;
+ tile[(kl*VW3 + 3) * WGS3 + lid] = avec.s3;
+ tile[(kl*VW3 + 4) * WGS3 + lid] = avec.s4;
+ tile[(kl*VW3 + 5) * WGS3 + lid] = avec.s5;
+ tile[(kl*VW3 + 6) * WGS3 + lid] = avec.s6;
+ tile[(kl*VW3 + 7) * WGS3 + lid] = avec.s7;
+ tile[(kl*VW3 + 8) * WGS3 + lid] = avec.s8;
+ tile[(kl*VW3 + 9) * WGS3 + lid] = avec.s9;
+ tile[(kl*VW3 + 10) * WGS3 + lid] = avec.sA;
+ tile[(kl*VW3 + 11) * WGS3 + lid] = avec.sB;
+ tile[(kl*VW3 + 12) * WGS3 + lid] = avec.sC;
+ tile[(kl*VW3 + 13) * WGS3 + lid] = avec.sD;
+ tile[(kl*VW3 + 14) * WGS3 + lid] = avec.sE;
+ tile[(kl*VW3 + 15) * WGS3 + lid] = avec.sF;
+ #endif
+ }
// Synchronizes all threads in a workgroup
barrier(CLK_LOCAL_MEM_FENCE);
// The multiply-add function (rotated)
#pragma unroll
- for (int kl=0; kl<WGS3/VW3; ++kl) {
- const int k = (kwg/VW3) + kl;
- #pragma unroll
- for (int w=0; w<WPT3; ++w) {
- const int gid = WPT3*get_global_id(0) + w;
- realVFR avec = LoadMatrixAVFR(agm, k, gid, a_ld/VW3);
- #if VW3 == 1
- MultiplyAdd(acc[w], xlm[VW3*kl+0], avec);
- #elif VW3 == 2
- MultiplyAdd(acc[w], xlm[VW3*kl+0], avec.x);
- MultiplyAdd(acc[w], xlm[VW3*kl+1], avec.y);
- #elif VW3 == 4
- MultiplyAdd(acc[w], xlm[VW3*kl+0], avec.x);
- MultiplyAdd(acc[w], xlm[VW3*kl+1], avec.y);
- MultiplyAdd(acc[w], xlm[VW3*kl+2], avec.z);
- MultiplyAdd(acc[w], xlm[VW3*kl+3], avec.w);
- #elif VW3 == 8
- MultiplyAdd(acc[w], xlm[VW3*kl+0], avec.s0);
- MultiplyAdd(acc[w], xlm[VW3*kl+1], avec.s1);
- MultiplyAdd(acc[w], xlm[VW3*kl+2], avec.s2);
- MultiplyAdd(acc[w], xlm[VW3*kl+3], avec.s3);
- MultiplyAdd(acc[w], xlm[VW3*kl+4], avec.s4);
- MultiplyAdd(acc[w], xlm[VW3*kl+5], avec.s5);
- MultiplyAdd(acc[w], xlm[VW3*kl+6], avec.s6);
- MultiplyAdd(acc[w], xlm[VW3*kl+7], avec.s7);
- #elif VW3 == 16
- MultiplyAdd(acc[w], xlm[VW3*kl+0], avec.s0);
- MultiplyAdd(acc[w], xlm[VW3*kl+1], avec.s1);
- MultiplyAdd(acc[w], xlm[VW3*kl+2], avec.s2);
- MultiplyAdd(acc[w], xlm[VW3*kl+3], avec.s3);
- MultiplyAdd(acc[w], xlm[VW3*kl+4], avec.s4);
- MultiplyAdd(acc[w], xlm[VW3*kl+5], avec.s5);
- MultiplyAdd(acc[w], xlm[VW3*kl+6], avec.s6);
- MultiplyAdd(acc[w], xlm[VW3*kl+7], avec.s7);
- MultiplyAdd(acc[w], xlm[VW3*kl+8], avec.s8);
- MultiplyAdd(acc[w], xlm[VW3*kl+9], avec.s9);
- MultiplyAdd(acc[w], xlm[VW3*kl+10], avec.sA);
- MultiplyAdd(acc[w], xlm[VW3*kl+11], avec.sB);
- MultiplyAdd(acc[w], xlm[VW3*kl+12], avec.sC);
- MultiplyAdd(acc[w], xlm[VW3*kl+13], avec.sD);
- MultiplyAdd(acc[w], xlm[VW3*kl+14], avec.sE);
- MultiplyAdd(acc[w], xlm[VW3*kl+15], avec.sF);
- #endif
- }
+ for (int kl=0; kl<WPT3; ++kl) {
+ const int k = kl * (WGS3/WPT3) + lid_div;
+ real aval = tile[k * WPT3 + lid_mod];
+ real xval = xlm[kl];
+ MultiplyAdd(acc, xval, aval);
}
// Synchronizes all threads in a workgroup
@@ -282,12 +284,9 @@ __kernel void XgemvFastRot(const int m, const int n,
}
// Stores the final result
- #pragma unroll
- for (int w=0; w<WPT3; ++w) {
- const int gid = WPT3*get_global_id(0) + w;
- real yval = ygm[gid*y_inc + y_offset];
- AXPBY(ygm[gid*y_inc + y_offset], alpha, acc[w], beta, yval);
- }
+ const int gid = get_global_id(0);
+ real yval = ygm[gid * y_inc + y_offset];
+ AXPBY(ygm[gid * y_inc + y_offset], alpha, acc, beta, yval);
}
// =================================================================================================