summary | refs | log | tree | commit | diff
path: root/src/kernels/level3/xgemm_part2.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/level3/xgemm_part2.opencl')
-rw-r--r--  src/kernels/level3/xgemm_part2.opencl  68
1 files changed, 34 insertions, 34 deletions
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
index a5507458..88100e96 100644
--- a/src/kernels/level3/xgemm_part2.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -64,48 +64,48 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva
}
// Performs the actual computation: Cpm += Apm * Bpm
-INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) {
+INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI*MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) {
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
const realM aval = apm[_mi];
#if VWN == 1
- cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni]);
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
#elif VWN == 2
- cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].x);
- cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].y);
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
#elif VWN == 4
- cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].x);
- cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].y);
- cpm[_ni*VWN + 2][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2][_mi], aval, bpm[_ni].z);
- cpm[_ni*VWN + 3][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3][_mi], aval, bpm[_ni].w);
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
#elif VWN == 8
- cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].s0);
- cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].s1);
- cpm[_ni*VWN + 2][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2][_mi], aval, bpm[_ni].s2);
- cpm[_ni*VWN + 3][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3][_mi], aval, bpm[_ni].s3);
- cpm[_ni*VWN + 4][_mi] = MultiplyAddVector(cpm[_ni*VWN + 4][_mi], aval, bpm[_ni].s4);
- cpm[_ni*VWN + 5][_mi] = MultiplyAddVector(cpm[_ni*VWN + 5][_mi], aval, bpm[_ni].s5);
- cpm[_ni*VWN + 6][_mi] = MultiplyAddVector(cpm[_ni*VWN + 6][_mi], aval, bpm[_ni].s6);
- cpm[_ni*VWN + 7][_mi] = MultiplyAddVector(cpm[_ni*VWN + 7][_mi], aval, bpm[_ni].s7);
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
#elif VWN == 16
- cpm[_ni*VWN + 0 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0 ][_mi], aval, bpm[_ni].s0);
- cpm[_ni*VWN + 1 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1 ][_mi], aval, bpm[_ni].s1);
- cpm[_ni*VWN + 2 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2 ][_mi], aval, bpm[_ni].s2);
- cpm[_ni*VWN + 3 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3 ][_mi], aval, bpm[_ni].s3);
- cpm[_ni*VWN + 4 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 4 ][_mi], aval, bpm[_ni].s4);
- cpm[_ni*VWN + 5 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 5 ][_mi], aval, bpm[_ni].s5);
- cpm[_ni*VWN + 6 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 6 ][_mi], aval, bpm[_ni].s6);
- cpm[_ni*VWN + 7 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 7 ][_mi], aval, bpm[_ni].s7);
- cpm[_ni*VWN + 8 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 8 ][_mi], aval, bpm[_ni].s8);
- cpm[_ni*VWN + 9 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 9 ][_mi], aval, bpm[_ni].s9);
- cpm[_ni*VWN + 10][_mi] = MultiplyAddVector(cpm[_ni*VWN + 10][_mi], aval, bpm[_ni].sA);
- cpm[_ni*VWN + 11][_mi] = MultiplyAddVector(cpm[_ni*VWN + 11][_mi], aval, bpm[_ni].sB);
- cpm[_ni*VWN + 12][_mi] = MultiplyAddVector(cpm[_ni*VWN + 12][_mi], aval, bpm[_ni].sC);
- cpm[_ni*VWN + 13][_mi] = MultiplyAddVector(cpm[_ni*VWN + 13][_mi], aval, bpm[_ni].sD);
- cpm[_ni*VWN + 14][_mi] = MultiplyAddVector(cpm[_ni*VWN + 14][_mi], aval, bpm[_ni].sE);
- cpm[_ni*VWN + 15][_mi] = MultiplyAddVector(cpm[_ni*VWN + 15][_mi], aval, bpm[_ni].sF);
+ cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
+ cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
+ cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
+ cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
+ cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
+ cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
+ cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
+ cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
+ cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
#endif
}
}
@@ -115,7 +115,7 @@ INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM],
// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
-INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM,
+INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI*MWI/VWM], const int kSizeM,
const real alpha, const real beta) {
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
@@ -136,7 +136,7 @@ INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], cons
int index = idn*(kSizeM/VWM) + idm;
realM result;
- realM xval = cpm[_ni][_mi];
+ realM xval = cpm[_ni * (MWI/VWM) + _mi];
// The final multiplication with alpha (in case beta == 0)
if (IsZero(beta)) {