author    Cedric Nugteren <web@cedricnugteren.nl>  2016-09-25 11:38:35 +0200
committer Cedric Nugteren <web@cedricnugteren.nl>  2016-09-25 11:38:35 +0200
commit    140dc12854dd9521c1420ccba7eb9fb0d50e054e (patch)
tree      62384dd7764bff0331e55d47cc5c44953174be55 /src/kernels/level3
parent    6aa652d6ea2389744195ae5cd19321325b2d71aa (diff)
Added a first version of the direct version of GEMM with local memory
Diffstat (limited to 'src/kernels/level3')
-rw-r--r--  src/kernels/level3/xgemm_direct.opencl | 198
1 file changed, 194 insertions, 4 deletions
diff --git a/src/kernels/level3/xgemm_direct.opencl b/src/kernels/level3/xgemm_direct.opencl
index a5e8ca3d..fb5972ba 100644
--- a/src/kernels/level3/xgemm_direct.opencl
+++ b/src/kernels/level3/xgemm_direct.opencl
@@ -18,6 +18,164 @@ R"(
// =================================================================================================
+// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
+// caching the A input matrix.
+inline void GlobalToLocalDirectA(const __global realM* restrict agm, __local real* alm,
+ const int a_ld, const int a_offset, const int tid, const int kwg,
+ const int a_transpose, const int a_conjugate) {
+ const int la0 = tid % MDIMA;
+ const int la1 = tid / MDIMA;
+ #pragma unroll
+ for (int mia=0; mia<MWA/VWM; ++mia) {
+ #pragma unroll
+ for (int kia=0; kia<KWA; ++kia) {
+
+ // Computes the indices for the global memory
+ int mg = mia + la0*(MWA/VWM);
+ int kg = kia + la1*KWA;
+ int idm = (a_transpose) ? mg + kwg/VWM : mg + GetGroupID0()*(MWG/VWM);
+ int idk = (a_transpose) ? kg + GetGroupID0()*MWG : kg + kwg;
+
+ // Loads the data from global memory into the local memory
+ const realM avec = agm[idk*(a_ld/VWM) + idm + a_offset];
+ #if VWM == 1
+ alm[kg*MWG + mg] = avec;
+ #elif VWM == 2
+ alm[kg*MWG + mg*VWM + 0] = avec.x;
+ alm[kg*MWG + mg*VWM + 1] = avec.y;
+ #elif VWM == 4
+ alm[kg*MWG + mg*VWM + 0] = avec.x;
+ alm[kg*MWG + mg*VWM + 1] = avec.y;
+ alm[kg*MWG + mg*VWM + 2] = avec.z;
+ alm[kg*MWG + mg*VWM + 3] = avec.w;
+ #elif VWM == 8
+ alm[kg*MWG + mg*VWM + 0] = avec.s0;
+ alm[kg*MWG + mg*VWM + 1] = avec.s1;
+ alm[kg*MWG + mg*VWM + 2] = avec.s2;
+ alm[kg*MWG + mg*VWM + 3] = avec.s3;
+ alm[kg*MWG + mg*VWM + 4] = avec.s4;
+ alm[kg*MWG + mg*VWM + 5] = avec.s5;
+ alm[kg*MWG + mg*VWM + 6] = avec.s6;
+ alm[kg*MWG + mg*VWM + 7] = avec.s7;
+ #elif VWM == 16
+ alm[kg*MWG + mg*VWM + 0] = avec.s0;
+ alm[kg*MWG + mg*VWM + 1] = avec.s1;
+ alm[kg*MWG + mg*VWM + 2] = avec.s2;
+ alm[kg*MWG + mg*VWM + 3] = avec.s3;
+ alm[kg*MWG + mg*VWM + 4] = avec.s4;
+ alm[kg*MWG + mg*VWM + 5] = avec.s5;
+ alm[kg*MWG + mg*VWM + 6] = avec.s6;
+ alm[kg*MWG + mg*VWM + 7] = avec.s7;
+ alm[kg*MWG + mg*VWM + 8] = avec.s8;
+ alm[kg*MWG + mg*VWM + 9] = avec.s9;
+ alm[kg*MWG + mg*VWM + 10] = avec.sA;
+ alm[kg*MWG + mg*VWM + 11] = avec.sB;
+ alm[kg*MWG + mg*VWM + 12] = avec.sC;
+ alm[kg*MWG + mg*VWM + 13] = avec.sD;
+ alm[kg*MWG + mg*VWM + 14] = avec.sE;
+ alm[kg*MWG + mg*VWM + 15] = avec.sF;
+ #endif
+ if (a_conjugate) {
+ for (int vm=0; vm<VWM; ++vm) {
+ COMPLEX_CONJUGATE(alm[kg*MWG + mg*VWM + vm]);
+ }
+ }
+ }
+ }
+}
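
The index arithmetic of this cooperative load is easiest to see with concrete numbers. The following minimal host-side sketch in plain C (not part of the kernel) replays the mapping from a combined thread id to the (kg, mg) slots of the local tile; the MDIMA/MWA/VWM/KWA values and the MWG relation are illustrative assumptions, not CLBlast defaults.

/* Replays the GlobalToLocalDirectA index mapping on the host (plain C).
 * All tuning-parameter values below are assumed for illustration only. */
#include <stdio.h>

#define MDIMA 8            /* work-items along M in the work-group (assumed) */
#define MWA   4            /* A elements loaded per work-item along M (assumed) */
#define VWM   2            /* vector width for loads of A (assumed) */
#define KWA   2            /* A elements loaded per work-item along K (assumed) */
#define MWG   (MDIMA*MWA)  /* work-group tile width in M (assumed relation) */

int main(void) {
  for (int tid = 0; tid < 4; ++tid) {        /* first few combined thread ids */
    const int la0 = tid % MDIMA;             /* thread coordinate along M */
    const int la1 = tid / MDIMA;             /* thread coordinate along K */
    for (int mia = 0; mia < MWA/VWM; ++mia) {
      for (int kia = 0; kia < KWA; ++kia) {
        const int mg = mia + la0*(MWA/VWM);  /* vector column within the tile */
        const int kg = kia + la1*KWA;        /* row within the tile */
        printf("tid=%d: vector (kg=%d, mg=%d) scatters to alm[%d..%d]\n",
               tid, kg, mg, kg*MWG + mg*VWM, kg*MWG + mg*VWM + VWM - 1);
      }
    }
  }
  return 0;
}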
+
+// Same as above, but now for the B input matrix
+inline void GlobalToLocalDirectB(const __global realN* restrict bgm, __local real* blm,
+ const int b_ld, const int b_offset, const int tid, const int kwg,
+ const int b_transpose, const int b_conjugate) {
+ const int lb0 = tid % NDIMB;
+ const int lb1 = tid / NDIMB;
+ #pragma unroll
+ for (int kib=0; kib<KWB; ++kib) {
+ #pragma unroll
+ for (int nib=0; nib<NWB/VWN; ++nib) {
+
+ // Computes the indices for the global memory
+ int ng = nib + lb0*(NWB/VWN);
+ int kg = kib + lb1*KWB;
+ int idn = (b_transpose) ? ng + kwg/VWN : ng + GetGroupID1()*(NWG/VWN);
+ int idk = (b_transpose) ? kg + GetGroupID1()*NWG : kg + kwg;
+
+ // Loads the data from global memory into the local memory
+ const realN bvec = bgm[idk*(b_ld/VWN) + idn + b_offset];
+ #if VWN == 1
+ blm[kg*NWG + ng] = bvec;
+ #elif VWN == 2
+ blm[kg*NWG + ng*VWN + 0] = bvec.x;
+ blm[kg*NWG + ng*VWN + 1] = bvec.y;
+ #elif VWN == 4
+ blm[kg*NWG + ng*VWN + 0] = bvec.x;
+ blm[kg*NWG + ng*VWN + 1] = bvec.y;
+ blm[kg*NWG + ng*VWN + 2] = bvec.z;
+ blm[kg*NWG + ng*VWN + 3] = bvec.w;
+ #elif VWN == 8
+ blm[kg*NWG + ng*VWN + 0] = bvec.s0;
+ blm[kg*NWG + ng*VWN + 1] = bvec.s1;
+ blm[kg*NWG + ng*VWN + 2] = bvec.s2;
+ blm[kg*NWG + ng*VWN + 3] = bvec.s3;
+ blm[kg*NWG + ng*VWN + 4] = bvec.s4;
+ blm[kg*NWG + ng*VWN + 5] = bvec.s5;
+ blm[kg*NWG + ng*VWN + 6] = bvec.s6;
+ blm[kg*NWG + ng*VWN + 7] = bvec.s7;
+ #elif VWN == 16
+ blm[kg*NWG + ng*VWN + 0] = bvec.s0;
+ blm[kg*NWG + ng*VWN + 1] = bvec.s1;
+ blm[kg*NWG + ng*VWN + 2] = bvec.s2;
+ blm[kg*NWG + ng*VWN + 3] = bvec.s3;
+ blm[kg*NWG + ng*VWN + 4] = bvec.s4;
+ blm[kg*NWG + ng*VWN + 5] = bvec.s5;
+ blm[kg*NWG + ng*VWN + 6] = bvec.s6;
+ blm[kg*NWG + ng*VWN + 7] = bvec.s7;
+ blm[kg*NWG + ng*VWN + 8] = bvec.s8;
+ blm[kg*NWG + ng*VWN + 9] = bvec.s9;
+ blm[kg*NWG + ng*VWN + 10] = bvec.sA;
+ blm[kg*NWG + ng*VWN + 11] = bvec.sB;
+ blm[kg*NWG + ng*VWN + 12] = bvec.sC;
+ blm[kg*NWG + ng*VWN + 13] = bvec.sD;
+ blm[kg*NWG + ng*VWN + 14] = bvec.sE;
+ blm[kg*NWG + ng*VWN + 15] = bvec.sF;
+ #endif
+ if (b_conjugate) {
+ for (int vn=0; vn<VWN; ++vn) {
+ COMPLEX_CONJUGATE(blm[kg*NWG + ng*VWN + vn]);
+ }
+ }
+ }
+ }
+}
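
The #if VWM/#if VWN ladders above exist because VWM and VWN are compile-time tuning macros (so exactly one branch survives preprocessing) and because OpenCL vector components (.x, .s0, ...) are addressed with compile-time suffixes, not runtime indices. A plain-C stand-in for the same scatter, with an array in place of an OpenCL vector and assumed NWG/VWN values:

/* Plain-C stand-in for the per-component scatter done by the #if VWN ladder:
 * one VWN-wide vector read becomes VWN scalar stores into the local tile.
 * NWG and VWN values are assumed for illustration. */
#include <stdio.h>

#define NWG 8   /* work-group tile width in N (assumed) */
#define VWN 4   /* vector width for loads of B (assumed) */

int main(void) {
  float blm[2 * NWG];                    /* two rows of the local tile */
  const float bvec[VWN] = {1, 2, 3, 4};  /* stands in for an OpenCL float4 */
  const int kg = 1, ng = 1;              /* row and vector column in the tile */
  for (int vn = 0; vn < VWN; ++vn) {     /* the kernel unrolls this as .x/.y/.z/.w */
    blm[kg*NWG + ng*VWN + vn] = bvec[vn];
  }
  printf("wrote blm[%d..%d]\n", kg*NWG + ng*VWN, kg*NWG + ng*VWN + VWN - 1);
  return 0;
}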
+
+// =================================================================================================
+
+// Caches on-chip local memory into per-thread private memory (registers). This function is specific
+// for caching the A input matrix.
+inline void LocalToPrivateDirectA(__local real* alm, real apm[MWI], const int kg,
+ const int a_transpose) {
+ #pragma unroll
+ for (int mi=0; mi<MWI; ++mi) {
+ const int mg = mi + get_local_id(0)*MWI;
+ const int index = (a_transpose) ? mg*KWG + kg : kg*MWG + mg;
+ apm[mi] = alm[index];
+ }
+}
+
+// Same as above, but now for the B input matrix
+inline void LocalToPrivateDirectB(__local real* blm, real bpm[NWI], const int kg,
+ const int b_transpose) {
+ #pragma unroll
+ for (int ni=0; ni<NWI; ++ni) {
+ const int ng = ni + get_local_id(1)*NWI;
+ const int index = (b_transpose) ? ng*KWG + kg : kg*NWG + ng;
+ bpm[ni] = blm[index];
+ }
+}
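
In both functions the transpose flag merely selects between two linearisations of the same local tile: non-transposed reads walk a kg-major layout (kg*MWG + mg), transposed reads an mg-major one (mg*KWG + kg). A plain-C sketch that prints both index patterns for a tiny tile (MWG and KWG values are assumptions for illustration):

/* Prints the two local-tile index patterns used by LocalToPrivateDirectA.
 * MWG and KWG are assumed tiny values for illustration. */
#include <stdio.h>

#define MWG 4   /* tile width in M (assumed) */
#define KWG 2   /* tile depth in K (assumed) */

int main(void) {
  for (int a_transpose = 0; a_transpose <= 1; ++a_transpose) {
    printf("a_transpose=%d:\n", a_transpose);
    for (int kg = 0; kg < KWG; ++kg) {   /* one row of reads per k-step */
      for (int mg = 0; mg < MWG; ++mg) {
        const int index = a_transpose ? mg*KWG + kg : kg*MWG + mg;
        printf(" alm[%2d]", index);
      }
      printf("\n");
    }
  }
  return 0;
}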
+
+// =================================================================================================
+
// Initializes the accumulation registers to zero
inline void InitAccRegistersDirect(real cpm[NWI][MWI]) {
#pragma unroll
@@ -28,6 +186,7 @@ inline void InitAccRegistersDirect(real cpm[NWI][MWI]) {
}
}
}
+
// =================================================================================================
// Performs the actual computation: Cpm += Apm * Bpm
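
For reference, this accumulation is a rank-1 update of the per-thread register tile, cpm[ni][mi] += apm[mi] * bpm[ni], applied once per k-step. A self-contained plain-C sketch with assumed tiny tile sizes:

/* Rank-1 register-tile update as performed by MultiplyAccumulateDirect,
 * sketched in plain C. MWI and NWI values are assumed for illustration. */
#include <stdio.h>

#define MWI 2   /* per-thread tile width in M (assumed) */
#define NWI 2   /* per-thread tile height in N (assumed) */

int main(void) {
  float apm[MWI] = {1, 2};      /* stands in for the A register tile */
  float bpm[NWI] = {3, 4};      /* stands in for the B register tile */
  float cpm[NWI][MWI] = {{0}};  /* accumulators, zero-initialised */
  for (int ni = 0; ni < NWI; ++ni) {
    for (int mi = 0; mi < MWI; ++mi) {
      cpm[ni][mi] += apm[mi] * bpm[ni];   /* one k-step of the accumulation */
    }
  }
  printf("cpm = {{%.0f, %.0f}, {%.0f, %.0f}}\n",
         cpm[0][0], cpm[0][1], cpm[1][0], cpm[1][1]);
  return 0;
}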
@@ -88,6 +247,13 @@ __kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
const __global real* restrict agms = (const __global real* restrict) agm;
const __global real* restrict bgms = (const __global real* restrict) bgm;
+ // Allocates workgroup-private memory (local memory)
+ __local real alm[KWG * MWG];
+ __local real blm[KWG * NWG];
+
+ // Combined thread identifier (volatile, to keep the compiler from caching the computed value)
+ volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
+
// Allocates workitem-private memory (registers)
real apm[MWI];
real bpm[NWI];
@@ -97,15 +263,39 @@ __kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
InitAccRegistersDirect(cpm);
// The faster version of GEMM is not allowed on the (incomplete) borders. Therefore, this section
- // processes only the main parts: output blocks of NWI by MWI.
+ // processes only the main parts: output blocks of MWG by NWG.
const int idm = get_local_id(0) * MWI + GetGroupID0() * MWG;
const int idn = get_local_id(1) * NWI + GetGroupID1() * NWG;
- if ((idm < kSizeM - MWI) && (idn < kSizeN - NWI)) {
+ if ((idm < (kSizeM/MWG)*MWG) && (idn < (kSizeN/NWG)*NWG) &&
+ (a_ld % VWM == 0) && (b_ld % VWN == 0)) {
// Loops over all complete workgroup tiles
int kwg = 0;
- // TODO: Implement a faster version with local memory and vector loads
- // for (; kwg < kSizeK - KWG; kwg+=KWG) { }
+ for (; kwg < (kSizeK/KWG) * KWG; kwg+=KWG) {
+
+ // Loads data: off-chip --> local (matrix A and B)
+ GlobalToLocalDirectA(agm, alm, a_ld, a_offset, tid, kwg, a_transpose, a_conjugate);
+ GlobalToLocalDirectB(bgm, blm, b_ld, b_offset, tid, kwg, b_transpose, b_conjugate);
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ // Loops over all workitem tiles, unrolled by a factor KWI
+ for (int pwi=0; pwi<KWG; pwi+=KWI) {
+ #pragma unroll
+ for (int pit=0; pit<KWI; ++pit) {
+ int kg = pwi + pit;
+
+ // Loads data: local --> private (matrix A)
+ LocalToPrivateDirectA(alm, apm, kg, a_transpose);
+
+ // Loads data: local --> private (matrix B)
+ LocalToPrivateDirectB(blm, bpm, kg, b_transpose);
+
+ // Performs the accumulation (Cpm += Apm * Bpm)
+ MultiplyAccumulateDirect(cpm, apm, bpm);
+ }
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
// Loop over the remaining part (incomplete tile in K-dimension)
for (; kwg < kSizeK; ++kwg) {
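
Taken together, the new fast path is the classic barrier-synchronised tiling loop: stage a KWG-deep slice of A and B in local memory, barrier so every thread sees the complete tile, consume it from registers, then barrier again so no thread overwrites the tile before everyone is done; whatever kSizeK % KWG leaves over falls through to the original element-wise loop. A plain-C sketch of just that loop-bound split (kSizeK and KWG values are illustrative):

/* Demonstrates the k-dimension split used above: complete KWG-deep tiles go
 * through the fast local-memory path, the remainder is handled element-wise.
 * kSizeK and KWG values are assumed for illustration. */
#include <stdio.h>

int main(void) {
  const int kSizeK = 70;  /* assumed problem size in K */
  const int KWG = 32;     /* assumed tile depth in K */
  int kwg = 0;
  for (; kwg < (kSizeK/KWG) * KWG; kwg += KWG) {  /* complete tiles */
    printf("fast path:  k = [%d, %d)\n", kwg, kwg + KWG);
  }
  for (; kwg < kSizeK; ++kwg) {                   /* incomplete remainder */
    printf("remainder:  k = %d\n", kwg);
  }
  return 0;
}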