summaryrefslogtreecommitdiff
path: root/src/kernels/level3/xgemm_part2.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/level3/xgemm_part2.opencl')
-rw-r--r--src/kernels/level3/xgemm_part2.opencl124
1 files changed, 84 insertions, 40 deletions
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
index a1559b54..faf17e49 100644
--- a/src/kernels/level3/xgemm_part2.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -133,49 +133,93 @@ inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int
#endif
int idm = mg + GetGroupID0() * (MWG/VWM);
int idn = ng + GetGroupID1() * NWG;
-
- // The final multiplication with alpha and the addition with beta*C
int index = idn*(kSizeM/VWM) + idm;
+
realM result;
realM xval = cpm[ni][mi];
- realM yval = cgm[index];
- #if VWM == 1
- AXPBY(result, alpha, xval, beta, yval);
- #elif VWM == 2
- AXPBY(result.x, alpha, xval.x, beta, yval.x);
- AXPBY(result.y, alpha, xval.y, beta, yval.y);
- #elif VWM == 4
- AXPBY(result.x, alpha, xval.x, beta, yval.x);
- AXPBY(result.y, alpha, xval.y, beta, yval.y);
- AXPBY(result.z, alpha, xval.z, beta, yval.z);
- AXPBY(result.w, alpha, xval.w, beta, yval.w);
- #elif VWM == 8
- AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
- AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
- AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
- AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
- AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
- AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
- AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
- AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
- #elif VWM == 16
- AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
- AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
- AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
- AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
- AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
- AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
- AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
- AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
- AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
- AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
- AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
- AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
- AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
- AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
- AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
- AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
- #endif
+
+ // The final multiplication with alpha (in case beta == 0)
+ if (IsZero(beta)) {
+ #if VWM == 1
+ Multiply(result, alpha, xval);
+ #elif VWM == 2
+ Multiply(result.x, alpha, xval.x);
+ Multiply(result.y, alpha, xval.y);
+ #elif VWM == 4
+ Multiply(result.x, alpha, xval.x);
+ Multiply(result.y, alpha, xval.y);
+ Multiply(result.z, alpha, xval.z);
+ Multiply(result.w, alpha, xval.w);
+ #elif VWM == 8
+ Multiply(result.s0, alpha, xval.s0);
+ Multiply(result.s1, alpha, xval.s1);
+ Multiply(result.s2, alpha, xval.s2);
+ Multiply(result.s3, alpha, xval.s3);
+ Multiply(result.s4, alpha, xval.s4);
+ Multiply(result.s5, alpha, xval.s5);
+ Multiply(result.s6, alpha, xval.s6);
+ Multiply(result.s7, alpha, xval.s7);
+ #elif VWM == 16
+ Multiply(result.s0, alpha, xval.s0);
+ Multiply(result.s1, alpha, xval.s1);
+ Multiply(result.s2, alpha, xval.s2);
+ Multiply(result.s3, alpha, xval.s3);
+ Multiply(result.s4, alpha, xval.s4);
+ Multiply(result.s5, alpha, xval.s5);
+ Multiply(result.s6, alpha, xval.s6);
+ Multiply(result.s7, alpha, xval.s7);
+ Multiply(result.s8, alpha, xval.s8);
+ Multiply(result.s9, alpha, xval.s9);
+ Multiply(result.sA, alpha, xval.sA);
+ Multiply(result.sB, alpha, xval.sB);
+ Multiply(result.sC, alpha, xval.sC);
+ Multiply(result.sD, alpha, xval.sD);
+ Multiply(result.sE, alpha, xval.sE);
+ Multiply(result.sF, alpha, xval.sF);
+ #endif
+ }
+
+ // The final multiplication with alpha and the addition with beta*C
+ else {
+ realM yval = cgm[index];
+ #if VWM == 1
+ AXPBY(result, alpha, xval, beta, yval);
+ #elif VWM == 2
+ AXPBY(result.x, alpha, xval.x, beta, yval.x);
+ AXPBY(result.y, alpha, xval.y, beta, yval.y);
+ #elif VWM == 4
+ AXPBY(result.x, alpha, xval.x, beta, yval.x);
+ AXPBY(result.y, alpha, xval.y, beta, yval.y);
+ AXPBY(result.z, alpha, xval.z, beta, yval.z);
+ AXPBY(result.w, alpha, xval.w, beta, yval.w);
+ #elif VWM == 8
+ AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
+ AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
+ AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
+ AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
+ AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
+ AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
+ AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
+ AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
+ #elif VWM == 16
+ AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
+ AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
+ AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
+ AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
+ AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
+ AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
+ AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
+ AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
+ AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
+ AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
+ AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
+ AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
+ AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
+ AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
+ AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
+ AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
+ #endif
+ }
cgm[index] = result;
}
}