author    Cedric Nugteren <web@cedricnugteren.nl>    2016-05-18 21:18:07 +0200
committer Cedric Nugteren <web@cedricnugteren.nl>    2016-05-18 21:18:07 +0200
commit    181eb20bbf15cf11baaf6112b6965050c49dd543 (patch)
tree      508d503f735f9edf4aed939dbc20ef66ffc18b55 /src/kernels
parent    d91356a6b7f13fa1acd4395b217bb8a273ff8f70 (diff)
parent    9a061528eb006f6a59b8bdc12c0e802bd28941cf (diff)
Merge pull request #60 from CNugteren/development
Update to version 0.7.1
Diffstat (limited to 'src/kernels')
-rw-r--r--  src/kernels/common.opencl               |  26
-rw-r--r--  src/kernels/level3/xgemm_part1.opencl   |   8
-rw-r--r--  src/kernels/level3/xgemm_part2.opencl   | 138
3 files changed, 101 insertions, 71 deletions
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index d401744d..b9e52e17 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -176,6 +176,32 @@ R"(
// =================================================================================================
+// Shuffled workgroup indices to avoid partition camping, see below. For specific devices, this is
+// enabled (see src/routine.cc).
+#ifndef USE_STAGGERED_INDICES
+ #define USE_STAGGERED_INDICES 0
+#endif
+
+// Staggered/shuffled group indices to avoid partition camping (AMD GPUs). Formulas are taken from:
+// http://docs.nvidia.com/cuda/samples/6_Advanced/transpose/doc/MatrixTranspose.pdf
+// More details: https://github.com/CNugteren/CLBlast/issues/53
+#if USE_STAGGERED_INDICES == 1
+ inline size_t GetGroupIDFlat() {
+ return get_group_id(0) + get_num_groups(0) * get_group_id(1);
+ }
+ inline size_t GetGroupID1() {
+ return (GetGroupIDFlat()) % get_num_groups(1);
+ }
+ inline size_t GetGroupID0() {
+ return ((GetGroupIDFlat() / get_num_groups(1)) + GetGroupID1()) % get_num_groups(0);
+ }
+#else
+ inline size_t GetGroupID1() { return get_group_id(1); }
+ inline size_t GetGroupID0() { return get_group_id(0); }
+#endif
+
+// =================================================================================================
+
// End of the C++11 raw string literal
)"
diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl
index 4cb0585b..a2a555de 100644
--- a/src/kernels/level3/xgemm_part1.opencl
+++ b/src/kernels/level3/xgemm_part1.opencl
@@ -199,7 +199,7 @@ inline void GlobalToLocalA(const __global realM* restrict agm, __local realM* al
// Computes the indices for the global memory
int kg = kia + la1*KWA;
- int idm = mg + get_group_id(0)*(MWG/VWM);
+ int idm = mg + GetGroupID0() * (MWG/VWM);
int idk = kg + kwg;
// Loads the data from global memory (not transposed) into the local memory
@@ -229,7 +229,7 @@ inline void GlobalToLocalB(const __global realN* restrict bgm, __local realN* bl
// Computes the indices for the global memory
int kg = kib + lb1*KWB;
- int idn = ng + get_group_id(1)*(NWG/VWN);
+ int idn = ng + GetGroupID1() * (NWG/VWN);
int idk = kg + kwg;
// Loads the data from global memory (transposed) into the local memory
@@ -257,7 +257,7 @@ inline void GlobalToPrivateA(const __global realM* restrict agm, realM apm[MWI/V
#endif
// Computes the indices for the global memory
- int idm = mg + get_group_id(0)*(MWG/VWM);
+ int idm = mg + GetGroupID0() * (MWG/VWM);
// Loads the data from global memory (not transposed) and stores into registers
apm[mi] = agm[idk*(kSizeM/VWM) + idm];
@@ -280,7 +280,7 @@ inline void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[NWI/V
#endif
// Computes the indices for the global memory
- int idn = ng + get_group_id(1)*(NWG/VWN);
+ int idn = ng + GetGroupID1() * (NWG/VWN);
// Loads the data from global memory (transposed) and stores into registers
bpm[ni] = bgm[idk*(kSizeN/VWN) + idn];
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
index c0760db6..599e01d5 100644
--- a/src/kernels/level3/xgemm_part2.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -69,42 +69,43 @@ inline void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], real
for (int ni=0; ni<NWI/VWN; ++ni) {
#pragma unroll
for (int mi=0; mi<MWI/VWM; ++mi) {
+ const realM aval = apm[mi];
#if VWN == 1
- cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni]);
+ cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], aval, bpm[ni]);
#elif VWN == 2
- cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x);
- cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y);
+ cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], aval, bpm[ni].x);
+ cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], aval, bpm[ni].y);
#elif VWN == 4
- cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x);
- cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y);
- cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].z);
- cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].w);
+ cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], aval, bpm[ni].x);
+ cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], aval, bpm[ni].y);
+ cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], aval, bpm[ni].z);
+ cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], aval, bpm[ni].w);
#elif VWN == 8
- cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].s0);
- cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].s1);
- cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].s2);
- cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].s3);
- cpm[ni*VWN + 4][mi] = MultiplyAddVector(cpm[ni*VWN + 4][mi], apm[mi], bpm[ni].s4);
- cpm[ni*VWN + 5][mi] = MultiplyAddVector(cpm[ni*VWN + 5][mi], apm[mi], bpm[ni].s5);
- cpm[ni*VWN + 6][mi] = MultiplyAddVector(cpm[ni*VWN + 6][mi], apm[mi], bpm[ni].s6);
- cpm[ni*VWN + 7][mi] = MultiplyAddVector(cpm[ni*VWN + 7][mi], apm[mi], bpm[ni].s7);
+ cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], aval, bpm[ni].s0);
+ cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], aval, bpm[ni].s1);
+ cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], aval, bpm[ni].s2);
+ cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], aval, bpm[ni].s3);
+ cpm[ni*VWN + 4][mi] = MultiplyAddVector(cpm[ni*VWN + 4][mi], aval, bpm[ni].s4);
+ cpm[ni*VWN + 5][mi] = MultiplyAddVector(cpm[ni*VWN + 5][mi], aval, bpm[ni].s5);
+ cpm[ni*VWN + 6][mi] = MultiplyAddVector(cpm[ni*VWN + 6][mi], aval, bpm[ni].s6);
+ cpm[ni*VWN + 7][mi] = MultiplyAddVector(cpm[ni*VWN + 7][mi], aval, bpm[ni].s7);
#elif VWN == 16
- cpm[ni*VWN + 0 ][mi] = MultiplyAddVector(cpm[ni*VWN + 0 ][mi], apm[mi], bpm[ni].s0);
- cpm[ni*VWN + 1 ][mi] = MultiplyAddVector(cpm[ni*VWN + 1 ][mi], apm[mi], bpm[ni].s1);
- cpm[ni*VWN + 2 ][mi] = MultiplyAddVector(cpm[ni*VWN + 2 ][mi], apm[mi], bpm[ni].s2);
- cpm[ni*VWN + 3 ][mi] = MultiplyAddVector(cpm[ni*VWN + 3 ][mi], apm[mi], bpm[ni].s3);
- cpm[ni*VWN + 4 ][mi] = MultiplyAddVector(cpm[ni*VWN + 4 ][mi], apm[mi], bpm[ni].s4);
- cpm[ni*VWN + 5 ][mi] = MultiplyAddVector(cpm[ni*VWN + 5 ][mi], apm[mi], bpm[ni].s5);
- cpm[ni*VWN + 6 ][mi] = MultiplyAddVector(cpm[ni*VWN + 6 ][mi], apm[mi], bpm[ni].s6);
- cpm[ni*VWN + 7 ][mi] = MultiplyAddVector(cpm[ni*VWN + 7 ][mi], apm[mi], bpm[ni].s7);
- cpm[ni*VWN + 8 ][mi] = MultiplyAddVector(cpm[ni*VWN + 8 ][mi], apm[mi], bpm[ni].s8);
- cpm[ni*VWN + 9 ][mi] = MultiplyAddVector(cpm[ni*VWN + 9 ][mi], apm[mi], bpm[ni].s9);
- cpm[ni*VWN + 10][mi] = MultiplyAddVector(cpm[ni*VWN + 10][mi], apm[mi], bpm[ni].sA);
- cpm[ni*VWN + 11][mi] = MultiplyAddVector(cpm[ni*VWN + 11][mi], apm[mi], bpm[ni].sB);
- cpm[ni*VWN + 12][mi] = MultiplyAddVector(cpm[ni*VWN + 12][mi], apm[mi], bpm[ni].sC);
- cpm[ni*VWN + 13][mi] = MultiplyAddVector(cpm[ni*VWN + 13][mi], apm[mi], bpm[ni].sD);
- cpm[ni*VWN + 14][mi] = MultiplyAddVector(cpm[ni*VWN + 14][mi], apm[mi], bpm[ni].sE);
- cpm[ni*VWN + 15][mi] = MultiplyAddVector(cpm[ni*VWN + 15][mi], apm[mi], bpm[ni].sF);
+ cpm[ni*VWN + 0 ][mi] = MultiplyAddVector(cpm[ni*VWN + 0 ][mi], aval, bpm[ni].s0);
+ cpm[ni*VWN + 1 ][mi] = MultiplyAddVector(cpm[ni*VWN + 1 ][mi], aval, bpm[ni].s1);
+ cpm[ni*VWN + 2 ][mi] = MultiplyAddVector(cpm[ni*VWN + 2 ][mi], aval, bpm[ni].s2);
+ cpm[ni*VWN + 3 ][mi] = MultiplyAddVector(cpm[ni*VWN + 3 ][mi], aval, bpm[ni].s3);
+ cpm[ni*VWN + 4 ][mi] = MultiplyAddVector(cpm[ni*VWN + 4 ][mi], aval, bpm[ni].s4);
+ cpm[ni*VWN + 5 ][mi] = MultiplyAddVector(cpm[ni*VWN + 5 ][mi], aval, bpm[ni].s5);
+ cpm[ni*VWN + 6 ][mi] = MultiplyAddVector(cpm[ni*VWN + 6 ][mi], aval, bpm[ni].s6);
+ cpm[ni*VWN + 7 ][mi] = MultiplyAddVector(cpm[ni*VWN + 7 ][mi], aval, bpm[ni].s7);
+ cpm[ni*VWN + 8 ][mi] = MultiplyAddVector(cpm[ni*VWN + 8 ][mi], aval, bpm[ni].s8);
+ cpm[ni*VWN + 9 ][mi] = MultiplyAddVector(cpm[ni*VWN + 9 ][mi], aval, bpm[ni].s9);
+ cpm[ni*VWN + 10][mi] = MultiplyAddVector(cpm[ni*VWN + 10][mi], aval, bpm[ni].sA);
+ cpm[ni*VWN + 11][mi] = MultiplyAddVector(cpm[ni*VWN + 11][mi], aval, bpm[ni].sB);
+ cpm[ni*VWN + 12][mi] = MultiplyAddVector(cpm[ni*VWN + 12][mi], aval, bpm[ni].sC);
+ cpm[ni*VWN + 13][mi] = MultiplyAddVector(cpm[ni*VWN + 13][mi], aval, bpm[ni].sD);
+ cpm[ni*VWN + 14][mi] = MultiplyAddVector(cpm[ni*VWN + 14][mi], aval, bpm[ni].sE);
+ cpm[ni*VWN + 15][mi] = MultiplyAddVector(cpm[ni*VWN + 15][mi], aval, bpm[ni].sF);
#endif
}
}
@@ -130,49 +131,52 @@ inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int
#elif STRN == 1
int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC;
#endif
- int idm = mg + get_group_id(0)*(MWG/VWM);
- int idn = ng + get_group_id(1)*NWG;
+ int idm = mg + GetGroupID0() * (MWG/VWM);
+ int idn = ng + GetGroupID1() * NWG;
// The final multiplication with alpha and the addition with beta*C
int index = idn*(kSizeM/VWM) + idm;
- realM cval = cgm[index];
+ realM result;
+ realM xval = cpm[ni][mi];
+ realM yval = cgm[index];
#if VWM == 1
- AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval);
+ AXPBY(result, alpha, xval, beta, yval);
#elif VWM == 2
- AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
- AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
+ AXPBY(result.x, alpha, xval.x, beta, yval.x);
+ AXPBY(result.y, alpha, xval.y, beta, yval.y);
#elif VWM == 4
- AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
- AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
- AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z);
- AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w);
+ AXPBY(result.x, alpha, xval.x, beta, yval.x);
+ AXPBY(result.y, alpha, xval.y, beta, yval.y);
+ AXPBY(result.z, alpha, xval.z, beta, yval.z);
+ AXPBY(result.w, alpha, xval.w, beta, yval.w);
#elif VWM == 8
- AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
- AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
- AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
- AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
- AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
- AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
- AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
- AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
+ AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
+ AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
+ AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
+ AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
+ AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
+ AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
+ AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
+ AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
#elif VWM == 16
- AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
- AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
- AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
- AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
- AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
- AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
- AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
- AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
- AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8);
- AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9);
- AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA);
- AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB);
- AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC);
- AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD);
- AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE);
- AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF);
+ AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
+ AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
+ AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
+ AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
+ AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
+ AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
+ AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
+ AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
+ AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
+ AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
+ AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
+ AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
+ AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
+ AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
+ AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
+ AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
#endif
+ cgm[index] = result;
}
}
}
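
Side note on the StoreResults change above: the old code read cval from cgm[index] and then wrote each vector component of cgm[index] back to global memory individually, while the new code reads C once into yval, accumulates alpha*cpm[ni][mi] + beta*C per component into a private result, and writes cgm[index] back with a single store. A minimal C sketch of that load-compute-store pattern follows; the float4-like struct and the plain axpby() helper are stand-ins for illustration only (the real kernel uses the realM vector type and the AXPBY macro).

#include <stdio.h>

// Hypothetical stand-in for an OpenCL float4 vector.
typedef struct { float x, y, z, w; } float4_t;

// Plain scalar helper: alpha * x + beta * y (illustrative replacement for AXPBY).
static float axpby(float alpha, float x, float beta, float y) {
  return alpha * x + beta * y;
}

int main(void) {
  float4_t cgm[1] = {{1.0f, 2.0f, 3.0f, 4.0f}};        // pretend tile of global C
  const float4_t xval = {10.0f, 20.0f, 30.0f, 40.0f};  // accumulated cpm[ni][mi]
  const float alpha = 2.0f, beta = 0.5f;

  // Read the global value once, compute into a register, store once.
  const float4_t yval = cgm[0];
  float4_t result;
  result.x = axpby(alpha, xval.x, beta, yval.x);
  result.y = axpby(alpha, xval.y, beta, yval.y);
  result.z = axpby(alpha, xval.z, beta, yval.z);
  result.w = axpby(alpha, xval.w, beta, yval.w);
  cgm[0] = result;  // single store back to "global" memory

  printf("%.1f %.1f %.1f %.1f\n", cgm[0].x, cgm[0].y, cgm[0].z, cgm[0].w);
  return 0;
}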
@@ -269,7 +273,7 @@ __kernel void XgemmUpper(const int kSizeN, const int kSizeK,
__global realM* cgm) {
// Skip this workgroup if it does not contain any threads contributing to the upper-triangle
- if (get_group_id(1)*NWG < get_group_id(0)*MWG) {
+ if (GetGroupID1()*NWG < GetGroupID0()*MWG) {
return;
}
@@ -306,7 +310,7 @@ __kernel void XgemmLower(const int kSizeN, const int kSizeK,
__global realM* cgm) {
// Skip this workgroup if it does not contain any threads contributing to the lower-triangle
- if (get_group_id(1)*NWG > get_group_id(0)*MWG) {
+ if (GetGroupID1()*NWG > GetGroupID0()*MWG) {
return;
}