diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-18 21:32:56 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-18 21:32:56 +0200 |
commit | 489c5d76cfe95a97542dfeaa6d8b19cd9100919a (patch) | |
tree | 31a7082f5847f3bd21af1f2aa5a7d1eb68d188db /src/kernels/level3/xgemm_part2.opencl | |
parent | 7a3b695db70810595ae17d9d753c3b926aa738c0 (diff) |
Merged in latest changes from 0.7.1 release
Diffstat (limited to 'src/kernels/level3/xgemm_part2.opencl')
-rw-r--r-- | src/kernels/level3/xgemm_part2.opencl | 138 |
1 files changed, 71 insertions, 67 deletions
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl index a8c8ebf5..56ccdb96 100644 --- a/src/kernels/level3/xgemm_part2.opencl +++ b/src/kernels/level3/xgemm_part2.opencl @@ -69,42 +69,43 @@ inline void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], real for (int ni=0; ni<NWI/VWN; ++ni) { #pragma unroll for (int mi=0; mi<MWI/VWM; ++mi) { + const realM aval = apm[mi]; #if VWN == 1 - cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni]); + cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], aval, bpm[ni]); #elif VWN == 2 - cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x); - cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y); + cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], aval, bpm[ni].x); + cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], aval, bpm[ni].y); #elif VWN == 4 - cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x); - cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y); - cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].z); - cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].w); + cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], aval, bpm[ni].x); + cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], aval, bpm[ni].y); + cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], aval, bpm[ni].z); + cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], aval, bpm[ni].w); #elif VWN == 8 - cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].s0); - cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].s1); - cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].s2); - cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].s3); - cpm[ni*VWN + 4][mi] = MultiplyAddVector(cpm[ni*VWN + 4][mi], apm[mi], bpm[ni].s4); - cpm[ni*VWN + 5][mi] = MultiplyAddVector(cpm[ni*VWN + 5][mi], apm[mi], bpm[ni].s5); - cpm[ni*VWN + 6][mi] = MultiplyAddVector(cpm[ni*VWN + 6][mi], apm[mi], bpm[ni].s6); - cpm[ni*VWN + 7][mi] = MultiplyAddVector(cpm[ni*VWN + 7][mi], apm[mi], bpm[ni].s7); + cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], aval, bpm[ni].s0); + cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], aval, bpm[ni].s1); + cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], aval, bpm[ni].s2); + cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], aval, bpm[ni].s3); + cpm[ni*VWN + 4][mi] = MultiplyAddVector(cpm[ni*VWN + 4][mi], aval, bpm[ni].s4); + cpm[ni*VWN + 5][mi] = MultiplyAddVector(cpm[ni*VWN + 5][mi], aval, bpm[ni].s5); + cpm[ni*VWN + 6][mi] = MultiplyAddVector(cpm[ni*VWN + 6][mi], aval, bpm[ni].s6); + cpm[ni*VWN + 7][mi] = MultiplyAddVector(cpm[ni*VWN + 7][mi], aval, bpm[ni].s7); #elif VWN == 16 - cpm[ni*VWN + 0 ][mi] = MultiplyAddVector(cpm[ni*VWN + 0 ][mi], apm[mi], bpm[ni].s0); - cpm[ni*VWN + 1 ][mi] = MultiplyAddVector(cpm[ni*VWN + 1 ][mi], apm[mi], bpm[ni].s1); - cpm[ni*VWN + 2 ][mi] = MultiplyAddVector(cpm[ni*VWN + 2 ][mi], apm[mi], bpm[ni].s2); - cpm[ni*VWN + 3 ][mi] = MultiplyAddVector(cpm[ni*VWN + 3 ][mi], apm[mi], bpm[ni].s3); - cpm[ni*VWN + 4 ][mi] = MultiplyAddVector(cpm[ni*VWN + 4 ][mi], apm[mi], bpm[ni].s4); - cpm[ni*VWN + 5 ][mi] = MultiplyAddVector(cpm[ni*VWN + 5 ][mi], apm[mi], bpm[ni].s5); - cpm[ni*VWN + 6 ][mi] = MultiplyAddVector(cpm[ni*VWN + 6 ][mi], apm[mi], bpm[ni].s6); - cpm[ni*VWN + 7 ][mi] = MultiplyAddVector(cpm[ni*VWN + 7 ][mi], apm[mi], bpm[ni].s7); - cpm[ni*VWN + 8 ][mi] = MultiplyAddVector(cpm[ni*VWN + 8 ][mi], apm[mi], bpm[ni].s8); - cpm[ni*VWN + 9 ][mi] = MultiplyAddVector(cpm[ni*VWN + 9 ][mi], apm[mi], bpm[ni].s9); - cpm[ni*VWN + 10][mi] = MultiplyAddVector(cpm[ni*VWN + 10][mi], apm[mi], bpm[ni].sA); - cpm[ni*VWN + 11][mi] = MultiplyAddVector(cpm[ni*VWN + 11][mi], apm[mi], bpm[ni].sB); - cpm[ni*VWN + 12][mi] = MultiplyAddVector(cpm[ni*VWN + 12][mi], apm[mi], bpm[ni].sC); - cpm[ni*VWN + 13][mi] = MultiplyAddVector(cpm[ni*VWN + 13][mi], apm[mi], bpm[ni].sD); - cpm[ni*VWN + 14][mi] = MultiplyAddVector(cpm[ni*VWN + 14][mi], apm[mi], bpm[ni].sE); - cpm[ni*VWN + 15][mi] = MultiplyAddVector(cpm[ni*VWN + 15][mi], apm[mi], bpm[ni].sF); + cpm[ni*VWN + 0 ][mi] = MultiplyAddVector(cpm[ni*VWN + 0 ][mi], aval, bpm[ni].s0); + cpm[ni*VWN + 1 ][mi] = MultiplyAddVector(cpm[ni*VWN + 1 ][mi], aval, bpm[ni].s1); + cpm[ni*VWN + 2 ][mi] = MultiplyAddVector(cpm[ni*VWN + 2 ][mi], aval, bpm[ni].s2); + cpm[ni*VWN + 3 ][mi] = MultiplyAddVector(cpm[ni*VWN + 3 ][mi], aval, bpm[ni].s3); + cpm[ni*VWN + 4 ][mi] = MultiplyAddVector(cpm[ni*VWN + 4 ][mi], aval, bpm[ni].s4); + cpm[ni*VWN + 5 ][mi] = MultiplyAddVector(cpm[ni*VWN + 5 ][mi], aval, bpm[ni].s5); + cpm[ni*VWN + 6 ][mi] = MultiplyAddVector(cpm[ni*VWN + 6 ][mi], aval, bpm[ni].s6); + cpm[ni*VWN + 7 ][mi] = MultiplyAddVector(cpm[ni*VWN + 7 ][mi], aval, bpm[ni].s7); + cpm[ni*VWN + 8 ][mi] = MultiplyAddVector(cpm[ni*VWN + 8 ][mi], aval, bpm[ni].s8); + cpm[ni*VWN + 9 ][mi] = MultiplyAddVector(cpm[ni*VWN + 9 ][mi], aval, bpm[ni].s9); + cpm[ni*VWN + 10][mi] = MultiplyAddVector(cpm[ni*VWN + 10][mi], aval, bpm[ni].sA); + cpm[ni*VWN + 11][mi] = MultiplyAddVector(cpm[ni*VWN + 11][mi], aval, bpm[ni].sB); + cpm[ni*VWN + 12][mi] = MultiplyAddVector(cpm[ni*VWN + 12][mi], aval, bpm[ni].sC); + cpm[ni*VWN + 13][mi] = MultiplyAddVector(cpm[ni*VWN + 13][mi], aval, bpm[ni].sD); + cpm[ni*VWN + 14][mi] = MultiplyAddVector(cpm[ni*VWN + 14][mi], aval, bpm[ni].sE); + cpm[ni*VWN + 15][mi] = MultiplyAddVector(cpm[ni*VWN + 15][mi], aval, bpm[ni].sF); #endif } } @@ -130,49 +131,52 @@ inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int #elif STRN == 1 int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC; #endif - int idm = mg + get_group_id(0)*(MWG/VWM); - int idn = ng + get_group_id(1)*NWG; + int idm = mg + GetGroupID0() * (MWG/VWM); + int idn = ng + GetGroupID1() * NWG; // The final multiplication with alpha and the addition with beta*C int index = idn*(kSizeM/VWM) + idm; - realM cval = cgm[index]; + realM result; + realM xval = cpm[ni][mi]; + realM yval = cgm[index]; #if VWM == 1 - AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval); + AXPBY(result, alpha, xval, beta, yval); #elif VWM == 2 - AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x); - AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y); + AXPBY(result.x, alpha, xval.x, beta, yval.x); + AXPBY(result.y, alpha, xval.y, beta, yval.y); #elif VWM == 4 - AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x); - AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y); - AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z); - AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w); + AXPBY(result.x, alpha, xval.x, beta, yval.x); + AXPBY(result.y, alpha, xval.y, beta, yval.y); + AXPBY(result.z, alpha, xval.z, beta, yval.z); + AXPBY(result.w, alpha, xval.w, beta, yval.w); #elif VWM == 8 - AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0); - AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1); - AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2); - AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3); - AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4); - AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5); - AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6); - AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7); + AXPBY(result.s0, alpha, xval.s0, beta, yval.s0); + AXPBY(result.s1, alpha, xval.s1, beta, yval.s1); + AXPBY(result.s2, alpha, xval.s2, beta, yval.s2); + AXPBY(result.s3, alpha, xval.s3, beta, yval.s3); + AXPBY(result.s4, alpha, xval.s4, beta, yval.s4); + AXPBY(result.s5, alpha, xval.s5, beta, yval.s5); + AXPBY(result.s6, alpha, xval.s6, beta, yval.s6); + AXPBY(result.s7, alpha, xval.s7, beta, yval.s7); #elif VWM == 16 - AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0); - AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1); - AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2); - AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3); - AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4); - AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5); - AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6); - AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7); - AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8); - AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9); - AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA); - AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB); - AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC); - AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD); - AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE); - AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF); + AXPBY(result.s0, alpha, xval.s0, beta, yval.s0); + AXPBY(result.s1, alpha, xval.s1, beta, yval.s1); + AXPBY(result.s2, alpha, xval.s2, beta, yval.s2); + AXPBY(result.s3, alpha, xval.s3, beta, yval.s3); + AXPBY(result.s4, alpha, xval.s4, beta, yval.s4); + AXPBY(result.s5, alpha, xval.s5, beta, yval.s5); + AXPBY(result.s6, alpha, xval.s6, beta, yval.s6); + AXPBY(result.s7, alpha, xval.s7, beta, yval.s7); + AXPBY(result.s8, alpha, xval.s8, beta, yval.s8); + AXPBY(result.s9, alpha, xval.s9, beta, yval.s9); + AXPBY(result.sA, alpha, xval.sA, beta, yval.sA); + AXPBY(result.sB, alpha, xval.sB, beta, yval.sB); + AXPBY(result.sC, alpha, xval.sC, beta, yval.sC); + AXPBY(result.sD, alpha, xval.sD, beta, yval.sD); + AXPBY(result.sE, alpha, xval.sE, beta, yval.sE); + AXPBY(result.sF, alpha, xval.sF, beta, yval.sF); #endif + cgm[index] = result; } } } @@ -272,7 +276,7 @@ __kernel void XgemmUpper(const int kSizeN, const int kSizeK, const real beta = arg_beta[0]; // Skip these threads if they do not contain threads contributing to the upper-triangle - if (get_group_id(1)*NWG < get_group_id(0)*MWG) { + if (GetGroupID1()*NWG < GetGroupID0()*MWG) { return; } @@ -312,7 +316,7 @@ __kernel void XgemmLower(const int kSizeN, const int kSizeK, const real beta = arg_beta[0]; // Skip these threads if they do not contain threads contributing to the lower-triangle - if (get_group_id(1)*NWG > get_group_id(0)*MWG) { + if (GetGroupID1()*NWG > GetGroupID0()*MWG) { return; } |