diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-12-07 22:05:29 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-12-07 22:05:29 +0100 |
commit | 540896476d62ce37e7a939d185c15dc930b8a343 (patch) | |
tree | f9799153ab3fccebc5c3b3a9aa2b1c2db46e47c2 /src/kernels/level3/xgemm_part2.opencl | |
parent | 0f9637bbac6248a381d7012d7224331d3d394efb (diff) |
Added register promotion to the main GEMM kernel
Diffstat (limited to 'src/kernels/level3/xgemm_part2.opencl')
-rw-r--r-- | src/kernels/level3/xgemm_part2.opencl | 68 |
1 file changed, 34 insertions, 34 deletions
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl index a5507458..88100e96 100644 --- a/src/kernels/level3/xgemm_part2.opencl +++ b/src/kernels/level3/xgemm_part2.opencl @@ -64,48 +64,48 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva } // Performs the actual computation: Cpm += Apm * Bpm -INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) { +INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI*MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) { #pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { const realM aval = apm[_mi]; #if VWN == 1 - cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni]); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); #elif VWN == 2 - cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].x); - cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); #elif VWN == 4 - cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].x); - cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].y); - cpm[_ni*VWN + 2][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2][_mi], aval, bpm[_ni].z); - cpm[_ni*VWN + 3][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3][_mi], aval, bpm[_ni].w); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 
2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); #elif VWN == 8 - cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].s0); - cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].s1); - cpm[_ni*VWN + 2][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2][_mi], aval, bpm[_ni].s2); - cpm[_ni*VWN + 3][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3][_mi], aval, bpm[_ni].s3); - cpm[_ni*VWN + 4][_mi] = MultiplyAddVector(cpm[_ni*VWN + 4][_mi], aval, bpm[_ni].s4); - cpm[_ni*VWN + 5][_mi] = MultiplyAddVector(cpm[_ni*VWN + 5][_mi], aval, bpm[_ni].s5); - cpm[_ni*VWN + 6][_mi] = MultiplyAddVector(cpm[_ni*VWN + 6][_mi], aval, bpm[_ni].s6); - cpm[_ni*VWN + 7][_mi] = MultiplyAddVector(cpm[_ni*VWN + 7][_mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); #elif VWN == 16 - cpm[_ni*VWN + 0 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0 ][_mi], aval, bpm[_ni].s0); - cpm[_ni*VWN + 1 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1 ][_mi], aval, bpm[_ni].s1); - cpm[_ni*VWN + 2 ][_mi] = 
MultiplyAddVector(cpm[_ni*VWN + 2 ][_mi], aval, bpm[_ni].s2); - cpm[_ni*VWN + 3 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3 ][_mi], aval, bpm[_ni].s3); - cpm[_ni*VWN + 4 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 4 ][_mi], aval, bpm[_ni].s4); - cpm[_ni*VWN + 5 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 5 ][_mi], aval, bpm[_ni].s5); - cpm[_ni*VWN + 6 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 6 ][_mi], aval, bpm[_ni].s6); - cpm[_ni*VWN + 7 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 7 ][_mi], aval, bpm[_ni].s7); - cpm[_ni*VWN + 8 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 8 ][_mi], aval, bpm[_ni].s8); - cpm[_ni*VWN + 9 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 9 ][_mi], aval, bpm[_ni].s9); - cpm[_ni*VWN + 10][_mi] = MultiplyAddVector(cpm[_ni*VWN + 10][_mi], aval, bpm[_ni].sA); - cpm[_ni*VWN + 11][_mi] = MultiplyAddVector(cpm[_ni*VWN + 11][_mi], aval, bpm[_ni].sB); - cpm[_ni*VWN + 12][_mi] = MultiplyAddVector(cpm[_ni*VWN + 12][_mi], aval, bpm[_ni].sC); - cpm[_ni*VWN + 13][_mi] = MultiplyAddVector(cpm[_ni*VWN + 13][_mi], aval, bpm[_ni].sD); - cpm[_ni*VWN + 14][_mi] = MultiplyAddVector(cpm[_ni*VWN + 14][_mi], aval, bpm[_ni].sE); - cpm[_ni*VWN + 15][_mi] = MultiplyAddVector(cpm[_ni*VWN + 15][_mi], aval, bpm[_ni].sF); + cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + 
_mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); + cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); + cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); + cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); + cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); + cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); + cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); + cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); #endif } } @@ -115,7 +115,7 @@ INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], // Merges the results in Cpm with the global array in Cgm. This also performs the multiplication // with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm -INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM, +INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI*MWI/VWM], const int kSizeM, const real alpha, const real beta) { #pragma unroll for (int _ni = 0; _ni < NWI; _ni += 1) { @@ -136,7 +136,7 @@ INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], cons int index = idn*(kSizeM/VWM) + idm; realM result; - realM xval = cpm[_ni][_mi]; + realM xval = cpm[_ni * (MWI/VWM) + _mi]; // The final multiplication with alpha (in case beta == 0) if (IsZero(beta)) { |