author     CNugteren <web@cedricnugteren.nl>  2015-08-03 07:37:14 +0200
committer  CNugteren <web@cedricnugteren.nl>  2015-08-03 07:37:14 +0200
commit     d1a7cf18ecfee1879d00e3a19ce129ee058dd84f (patch)
tree       e54462deb1fafcdb65f5d573ad1a788d87f08343 /src
parent     fc7cd434e15b51ff8d39a0fcc8acba4b861ffd18 (diff)
Abstracted loading of matrix A for GEMV kernel
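
This change replaces direct agm[...] indexing in the GEMV kernels with inline load
helpers (LoadMatrixA, LoadMatrixAVF, LoadMatrixAVFR), selected at compile time via
ROUTINE_SYMV / ROUTINE_HEMV. In this commit the symmetric and Hermitian branches in
the hunk below are still empty placeholders; only the regular case is implemented.
To illustrate where the abstraction is headed, a symmetric variant could mirror
accesses across the diagonal, exploiting A[x][y] == A[y][x]. A minimal sketch,
not part of this commit:

// Hypothetical sketch only (NOT in this commit): a symmetric load that mirrors
// accesses outside the stored (here: lower) triangle across the diagonal.
inline real LoadMatrixA(const __global real* restrict agm, const int x, const int y,
                        const int a_ld, const int a_offset) {
  if (y <= x) { return agm[x + a_ld*y + a_offset]; }  // element lies in the stored triangle
  else        { return agm[y + a_ld*x + a_offset]; }  // mirrored load from the other triangle
}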
Diffstat (limited to 'src')
-rw-r--r--  src/kernels/xgemv.opencl | 156
1 file changed, 94 insertions, 62 deletions
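
Note that the VW2/VW3 typedef blocks are hoisted above the new helpers because
LoadMatrixAVF and LoadMatrixAVFR return the vector types realVF and realVFR. Since
agm is then indexed as an array of vectors, the leading dimension is passed as
a_ld/VW2 (or a_ld/VW3). A minimal sketch of the equivalence, assuming VW2 == 4 and
an a_ld divisible by VW2:

// Sketch, assuming VW2 == 4 (realVF == real4): one vector load covers four
// consecutive scalar elements of column k, hence the a_ld/VW2 leading dimension.
realVF avec = LoadMatrixAVF(agm, gid, k, a_ld/VW2);  // same as agm[gid + (a_ld/VW2)*k]
// avec.x .. avec.w correspond to scalar elements 4*gid+0 .. 4*gid+3 of column k.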
diff --git a/src/kernels/xgemv.opencl b/src/kernels/xgemv.opencl
index 65061717..5bbf69b9 100644
--- a/src/kernels/xgemv.opencl
+++ b/src/kernels/xgemv.opencl
@@ -52,6 +52,63 @@ R"(
// =================================================================================================
+// Data-widths for the 'fast' kernel
+#if VW2 == 1
+ typedef real realVF;
+#elif VW2 == 2
+ typedef real2 realVF;
+#elif VW2 == 4
+ typedef real4 realVF;
+#elif VW2 == 8
+ typedef real8 realVF;
+#elif VW2 == 16
+ typedef real16 realVF;
+#endif
+
+// Data-widths for the 'fast' kernel with rotated matrix
+#if VW3 == 1
+ typedef real realVFR;
+#elif VW3 == 2
+ typedef real2 realVFR;
+#elif VW3 == 4
+ typedef real4 realVFR;
+#elif VW3 == 8
+ typedef real8 realVFR;
+#elif VW3 == 16
+ typedef real16 realVFR;
+#endif
+
+// =================================================================================================
+// Defines how to load the input matrix in case of a symmetric matrix
+#if defined(ROUTINE_SYMV)
+
+// =================================================================================================
+// Defines how to load the input matrix in case of a Hermitian matrix
+#elif defined(ROUTINE_HEMV)
+
+// =================================================================================================
+// Defines how to load the input matrix in the regular case
+#else
+
+// Loads a scalar input value
+inline real LoadMatrixA(const __global real* restrict agm, const int x, const int y,
+ const int a_ld, const int a_offset) {
+ return agm[x + a_ld*y + a_offset];
+}
+// Loads a vector input value (1/2)
+inline realVF LoadMatrixAVF(const __global realVF* restrict agm, const int x, const int y,
+ const int a_ld) {
+ return agm[x + a_ld*y];
+}
+// Loads a vector input value (2/2): as before, but with a different data-type
+inline realVFR LoadMatrixAVFR(const __global realVFR* restrict agm, const int x, const int y,
+ const int a_ld) {
+ return agm[x + a_ld*y];
+}
+
+#endif
+// =================================================================================================
+
// Full version of the kernel
__attribute__((reqd_work_group_size(WGS1, 1, 1)))
__kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
@@ -96,7 +153,7 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
#pragma unroll
for (int kl=0; kl<WGS1; ++kl) {
const int k = kwg + kl;
- real value = agm[gid + a_ld*k + a_offset];
+ real value = LoadMatrixA(agm, gid, k, a_ld, a_offset);
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
MultiplyAdd(acc[w], xlm[kl], value);
}
@@ -105,7 +162,7 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
#pragma unroll
for (int kl=0; kl<WGS1; ++kl) {
const int k = kwg + kl;
- real value = agm[k + a_ld*gid + a_offset];
+ real value = LoadMatrixA(agm, k, gid, a_ld, a_offset);
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
MultiplyAdd(acc[w], xlm[kl], value);
}
@@ -127,7 +184,7 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
if (a_rotated == 0) { // Not rotated
#pragma unroll
for (int k=n_floor; k<n; ++k) {
- real value = agm[gid + a_ld*k + a_offset];
+ real value = LoadMatrixA(agm, gid, k, a_ld, a_offset);
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value);
}
@@ -135,7 +192,7 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
else { // Transposed
#pragma unroll
for (int k=n_floor; k<n; ++k) {
- real value = agm[k + a_ld*gid + a_offset];
+ real value = LoadMatrixA(agm, k, gid, a_ld, a_offset);
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value);
}
@@ -150,19 +207,6 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
// =================================================================================================
-// Data-widths for the 'fast' kernel
-#if VW2 == 1
- typedef real realVF;
-#elif VW2 == 2
- typedef real2 realVF;
-#elif VW2 == 4
- typedef real4 realVF;
-#elif VW2 == 8
- typedef real8 realVF;
-#elif VW2 == 16
- typedef real16 realVF;
-#endif
-
// Faster version of the kernel, assuming that:
// --> 'm' and 'n' are multiples of WGS2
// --> 'a_offset' is 0
@@ -203,42 +247,43 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b
#pragma unroll
for (int w=0; w<WPT2/VW2; ++w) {
const int gid = (WPT2/VW2)*get_global_id(0) + w;
+ realVF avec = LoadMatrixAVF(agm, gid, k, a_ld/VW2);
#if VW2 == 1
- MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k]);
+ MultiplyAdd(acc[VW2*w+0], xlm[kl], avec);
#elif VW2 == 2
- MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k].x);
- MultiplyAdd(acc[VW2*w+1], xlm[kl], agm[gid + (a_ld/VW2)*k].y);
+ MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.x);
+ MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.y);
#elif VW2 == 4
- MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k].x);
- MultiplyAdd(acc[VW2*w+1], xlm[kl], agm[gid + (a_ld/VW2)*k].y);
- MultiplyAdd(acc[VW2*w+2], xlm[kl], agm[gid + (a_ld/VW2)*k].z);
- MultiplyAdd(acc[VW2*w+3], xlm[kl], agm[gid + (a_ld/VW2)*k].w);
+ MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.x);
+ MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.y);
+ MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.z);
+ MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.w);
#elif VW2 == 8
- MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k].s0);
- MultiplyAdd(acc[VW2*w+1], xlm[kl], agm[gid + (a_ld/VW2)*k].s1);
- MultiplyAdd(acc[VW2*w+2], xlm[kl], agm[gid + (a_ld/VW2)*k].s2);
- MultiplyAdd(acc[VW2*w+3], xlm[kl], agm[gid + (a_ld/VW2)*k].s3);
- MultiplyAdd(acc[VW2*w+4], xlm[kl], agm[gid + (a_ld/VW2)*k].s4);
- MultiplyAdd(acc[VW2*w+5], xlm[kl], agm[gid + (a_ld/VW2)*k].s5);
- MultiplyAdd(acc[VW2*w+6], xlm[kl], agm[gid + (a_ld/VW2)*k].s6);
- MultiplyAdd(acc[VW2*w+7], xlm[kl], agm[gid + (a_ld/VW2)*k].s7);
+ MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.s0);
+ MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.s1);
+ MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.s2);
+ MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.s3);
+ MultiplyAdd(acc[VW2*w+4], xlm[kl], avec.s4);
+ MultiplyAdd(acc[VW2*w+5], xlm[kl], avec.s5);
+ MultiplyAdd(acc[VW2*w+6], xlm[kl], avec.s6);
+ MultiplyAdd(acc[VW2*w+7], xlm[kl], avec.s7);
#elif VW2 == 16
- MultiplyAdd(acc[VW2*w+0], xlm[kl], agm[gid + (a_ld/VW2)*k].s0);
- MultiplyAdd(acc[VW2*w+1], xlm[kl], agm[gid + (a_ld/VW2)*k].s1);
- MultiplyAdd(acc[VW2*w+2], xlm[kl], agm[gid + (a_ld/VW2)*k].s2);
- MultiplyAdd(acc[VW2*w+3], xlm[kl], agm[gid + (a_ld/VW2)*k].s3);
- MultiplyAdd(acc[VW2*w+4], xlm[kl], agm[gid + (a_ld/VW2)*k].s4);
- MultiplyAdd(acc[VW2*w+5], xlm[kl], agm[gid + (a_ld/VW2)*k].s5);
- MultiplyAdd(acc[VW2*w+6], xlm[kl], agm[gid + (a_ld/VW2)*k].s6);
- MultiplyAdd(acc[VW2*w+7], xlm[kl], agm[gid + (a_ld/VW2)*k].s7);
- MultiplyAdd(acc[VW2*w+8], xlm[kl], agm[gid + (a_ld/VW2)*k].s8);
- MultiplyAdd(acc[VW2*w+9], xlm[kl], agm[gid + (a_ld/VW2)*k].s9);
- MultiplyAdd(acc[VW2*w+10], xlm[kl], agm[gid + (a_ld/VW2)*k].sA);
- MultiplyAdd(acc[VW2*w+11], xlm[kl], agm[gid + (a_ld/VW2)*k].sB);
- MultiplyAdd(acc[VW2*w+12], xlm[kl], agm[gid + (a_ld/VW2)*k].sC);
- MultiplyAdd(acc[VW2*w+13], xlm[kl], agm[gid + (a_ld/VW2)*k].sD);
- MultiplyAdd(acc[VW2*w+14], xlm[kl], agm[gid + (a_ld/VW2)*k].sE);
- MultiplyAdd(acc[VW2*w+15], xlm[kl], agm[gid + (a_ld/VW2)*k].sF);
+ MultiplyAdd(acc[VW2*w+0], xlm[kl], avec.s0);
+ MultiplyAdd(acc[VW2*w+1], xlm[kl], avec.s1);
+ MultiplyAdd(acc[VW2*w+2], xlm[kl], avec.s2);
+ MultiplyAdd(acc[VW2*w+3], xlm[kl], avec.s3);
+ MultiplyAdd(acc[VW2*w+4], xlm[kl], avec.s4);
+ MultiplyAdd(acc[VW2*w+5], xlm[kl], avec.s5);
+ MultiplyAdd(acc[VW2*w+6], xlm[kl], avec.s6);
+ MultiplyAdd(acc[VW2*w+7], xlm[kl], avec.s7);
+ MultiplyAdd(acc[VW2*w+8], xlm[kl], avec.s8);
+ MultiplyAdd(acc[VW2*w+9], xlm[kl], avec.s9);
+ MultiplyAdd(acc[VW2*w+10], xlm[kl], avec.sA);
+ MultiplyAdd(acc[VW2*w+11], xlm[kl], avec.sB);
+ MultiplyAdd(acc[VW2*w+12], xlm[kl], avec.sC);
+ MultiplyAdd(acc[VW2*w+13], xlm[kl], avec.sD);
+ MultiplyAdd(acc[VW2*w+14], xlm[kl], avec.sE);
+ MultiplyAdd(acc[VW2*w+15], xlm[kl], avec.sF);
#endif
}
}
@@ -258,19 +303,6 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b
// =================================================================================================
-// Data-widths for the 'fast' kernel with rotated matrix
-#if VW3 == 1
- typedef real realVFR;
-#elif VW3 == 2
- typedef real2 realVFR;
-#elif VW3 == 4
- typedef real4 realVFR;
-#elif VW3 == 8
- typedef real8 realVFR;
-#elif VW3 == 16
- typedef real16 realVFR;
-#endif
-
// Faster version of the kernel, assuming that:
// --> 'm' and 'n' are multiples of WGS3
// --> 'a_offset' is 0
@@ -311,7 +343,7 @@ __kernel void XgemvFastRot(const int m, const int n, const real alpha, const rea
#pragma unroll
for (int w=0; w<WPT3; ++w) {
const int gid = WPT3*get_global_id(0) + w;
- realVFR avec = agm[k + (a_ld/VW3)*gid];
+ realVFR avec = LoadMatrixAVFR(agm, k, gid, a_ld/VW3);
#if VW3 == 1
MultiplyAdd(acc[w], xlm[VW3*kl+0], avec);
#elif VW3 == 2