author     CNugteren <web@cedricnugteren.nl>  2015-06-23 17:58:51 +0200
committer  CNugteren <web@cedricnugteren.nl>  2015-06-23 17:58:51 +0200
commit     9fc38cdf5ed44ef41cf3d6cf9e7c32585447c042 (patch)
tree       6f3163b61d780f6b25e21cc34dda68e76a2de637 /src/kernels
parent     0a3831e6d1eb437a9ef9ac7570f9a554b2c35edb (diff)
Added a lower/upper triangular version of the GEMM kernel
Diffstat (limited to 'src/kernels')
-rw-r--r--  src/kernels/xgemm.opencl | 371
1 file changed, 240 insertions(+), 131 deletions(-)
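
The new XgemmUpper and XgemmLower kernels compute only the workgroup tiles that can contribute to the requested triangle of C: each workgroup compares the starting column of its tile (get_group_id(1)*NWG) with the starting row (get_group_id(0)*MWG) and returns immediately when the tile lies on the wrong side of the diagonal. The following minimal host-side sketch is not part of the commit and uses assumed values for MWG, NWG and the matrix size; it only shows which tiles survive XgemmUpper's early-exit test.

/* Standalone C sketch (not CLBlast code): enumerates the workgroup grid and
 * applies XgemmUpper's early-exit test to show which output tiles are computed. */
#include <stdio.h>

#define MWG 64   /* assumed tile height of C per workgroup */
#define NWG 64   /* assumed tile width of C per workgroup  */

int main(void) {
  const int size = 256;  /* C is size x size for the triangular variants */
  for (int gid0 = 0; gid0 < size / MWG; ++gid0) {     /* row tiles (M)    */
    for (int gid1 = 0; gid1 < size / NWG; ++gid1) {   /* column tiles (N) */
      int skipped = (gid1 * NWG < gid0 * MWG);        /* XgemmUpper's test */
      printf("tile (%d,%d): %s\n", gid0, gid1, skipped ? "skipped" : "computed");
    }
  }
  return 0;
}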
diff --git a/src/kernels/xgemm.opencl b/src/kernels/xgemm.opencl
index a4f45e90..4c7ae064 100644
--- a/src/kernels/xgemm.opencl
+++ b/src/kernels/xgemm.opencl
@@ -127,6 +127,55 @@ R"(
// =================================================================================================
+// Initializes the accumulation registers to zero
+inline void InitAccRegisters(realM cpm[NWI][MWI/VWM]) {
+ #pragma unroll
+ for (int mi=0; mi<MWI/VWM; ++mi) {
+ #pragma unroll
+ for (int ni=0; ni<NWI; ++ni) {
+ #if VWM == 1
+ SetToZero(cpm[ni][mi]);
+ #elif VWM == 2
+ SetToZero(cpm[ni][mi].x);
+ SetToZero(cpm[ni][mi].y);
+ #elif VWM == 4
+ SetToZero(cpm[ni][mi].x);
+ SetToZero(cpm[ni][mi].y);
+ SetToZero(cpm[ni][mi].z);
+ SetToZero(cpm[ni][mi].w);
+ #elif VWM == 8
+ SetToZero(cpm[ni][mi].s0);
+ SetToZero(cpm[ni][mi].s1);
+ SetToZero(cpm[ni][mi].s2);
+ SetToZero(cpm[ni][mi].s3);
+ SetToZero(cpm[ni][mi].s4);
+ SetToZero(cpm[ni][mi].s5);
+ SetToZero(cpm[ni][mi].s6);
+ SetToZero(cpm[ni][mi].s7);
+ #elif VWM == 16
+ SetToZero(cpm[ni][mi].s0);
+ SetToZero(cpm[ni][mi].s1);
+ SetToZero(cpm[ni][mi].s2);
+ SetToZero(cpm[ni][mi].s3);
+ SetToZero(cpm[ni][mi].s4);
+ SetToZero(cpm[ni][mi].s5);
+ SetToZero(cpm[ni][mi].s6);
+ SetToZero(cpm[ni][mi].s7);
+ SetToZero(cpm[ni][mi].s8);
+ SetToZero(cpm[ni][mi].s9);
+ SetToZero(cpm[ni][mi].sA);
+ SetToZero(cpm[ni][mi].sB);
+ SetToZero(cpm[ni][mi].sC);
+ SetToZero(cpm[ni][mi].sD);
+ SetToZero(cpm[ni][mi].sE);
+ SetToZero(cpm[ni][mi].sF);
+ #endif
+ }
+ }
+}
+
+// =================================================================================================
+
// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
// caching the A input matrix.
#if SA == 1
@@ -272,71 +321,6 @@ inline void LocalToPrivateB(__local realN* blm, realN bpm[NWI/VWN], const int kg
// =================================================================================================
-// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
-// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
-inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM,
- const real alpha, const real beta) {
- #pragma unroll
- for (int ni=0; ni<NWI; ++ni) {
- #pragma unroll
- for (int mi=0; mi<MWI/VWM; ++mi) {
- #if STRM == 0
- int mg = mi + get_local_id(0)*(MWI/VWM);
- #elif STRM == 1
- int mg = get_local_id(0) + mi*MDIMC;
- #endif
- #if STRN == 0
- int ng = ni + get_local_id(1)*NWI;
- #elif STRN == 1
- int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC;
- #endif
- int idm = mg + get_group_id(0)*(MWG/VWM);
- int idn = ng + get_group_id(1)*NWG;
- int index = idn*(kSizeM/VWM) + idm;
- realM cval = cgm[index];
- #if VWM == 1
- AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval);
- #elif VWM == 2
- AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
- AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
- #elif VWM == 4
- AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
- AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
- AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z);
- AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w);
- #elif VWM == 8
- AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
- AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
- AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
- AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
- AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
- AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
- AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
- AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
- #elif VWM == 16
- AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
- AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
- AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
- AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
- AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
- AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
- AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
- AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
- AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8);
- AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9);
- AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA);
- AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB);
- AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC);
- AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD);
- AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE);
- AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF);
- #endif
- }
- }
-}
-
-// =================================================================================================
-
// The vectorised multiply-add function
inline realM MultiplyAddVector(realM cvec, const realM avec, const real bval) {
#if USE_VECTOR_MAD == 1
@@ -432,77 +416,97 @@ inline void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], real
// =================================================================================================
-// Main entry of the kernel. This function contains the basic skeleton, the functionality is
-// provided by the inlined functions above
-__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
-__kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
- const real alpha, const real beta,
- const __global realM* restrict agm,
- const __global realN* restrict bgm,
- __global realM* cgm) {
-
- // Combined thread identifier
- #if SA == 1 || SB == 1
- volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
- #endif
-
- // Allocates workgroup-private memory (local memory)
- #if SA == 1
- __local realM alm[KWG * MWG/VWM];
- #endif
- #if SB == 1
- __local realN blm[KWG * NWG/VWN];
- #endif
-
- // Allocates workitem-private memory (registers)
- realM apm[MWI/VWM];
- realN bpm[NWI/VWN];
- realM cpm[NWI][MWI/VWM];
-
- // Initializes the accumulation registers
+// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
+// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
+inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM,
+ const real alpha, const real beta) {
#pragma unroll
- for (int mi=0; mi<MWI/VWM; ++mi) {
+ for (int ni=0; ni<NWI; ++ni) {
#pragma unroll
- for (int ni=0; ni<NWI; ++ni) {
+ for (int mi=0; mi<MWI/VWM; ++mi) {
+ #if STRM == 0
+ int mg = mi + get_local_id(0)*(MWI/VWM);
+ #elif STRM == 1
+ int mg = get_local_id(0) + mi*MDIMC;
+ #endif
+ #if STRN == 0
+ int ng = ni + get_local_id(1)*NWI;
+ #elif STRN == 1
+ int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC;
+ #endif
+ int idm = mg + get_group_id(0)*(MWG/VWM);
+ int idn = ng + get_group_id(1)*NWG;
+
+ // The final multiplication with alpha and the addition with beta*C
+ int index = idn*(kSizeM/VWM) + idm;
+ realM cval = cgm[index];
#if VWM == 1
- SetToZero(cpm[ni][mi]);
+ AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval);
#elif VWM == 2
- SetToZero(cpm[ni][mi].x);
- SetToZero(cpm[ni][mi].y);
+ AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
+ AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
#elif VWM == 4
- SetToZero(cpm[ni][mi].x);
- SetToZero(cpm[ni][mi].y);
- SetToZero(cpm[ni][mi].z);
- SetToZero(cpm[ni][mi].w);
+ AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
+ AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
+ AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z);
+ AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w);
#elif VWM == 8
- SetToZero(cpm[ni][mi].s0);
- SetToZero(cpm[ni][mi].s1);
- SetToZero(cpm[ni][mi].s2);
- SetToZero(cpm[ni][mi].s3);
- SetToZero(cpm[ni][mi].s4);
- SetToZero(cpm[ni][mi].s5);
- SetToZero(cpm[ni][mi].s6);
- SetToZero(cpm[ni][mi].s7);
+ AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
+ AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
+ AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
+ AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
+ AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
+ AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
+ AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
+ AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
#elif VWM == 16
- SetToZero(cpm[ni][mi].s0);
- SetToZero(cpm[ni][mi].s1);
- SetToZero(cpm[ni][mi].s2);
- SetToZero(cpm[ni][mi].s3);
- SetToZero(cpm[ni][mi].s4);
- SetToZero(cpm[ni][mi].s5);
- SetToZero(cpm[ni][mi].s6);
- SetToZero(cpm[ni][mi].s7);
- SetToZero(cpm[ni][mi].s8);
- SetToZero(cpm[ni][mi].s9);
- SetToZero(cpm[ni][mi].sA);
- SetToZero(cpm[ni][mi].sB);
- SetToZero(cpm[ni][mi].sC);
- SetToZero(cpm[ni][mi].sD);
- SetToZero(cpm[ni][mi].sE);
- SetToZero(cpm[ni][mi].sF);
+ AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
+ AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
+ AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
+ AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
+ AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
+ AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
+ AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
+ AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
+ AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8);
+ AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9);
+ AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA);
+ AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB);
+ AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC);
+ AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD);
+ AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE);
+ AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF);
#endif
}
}
+}
+
+// =================================================================================================
+
+// Main body of the matrix-multiplication algorithm. It calls the (inlined) functions above.
+inline void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
+ const __global realM* restrict agm, const __global realN* restrict bgm,
+ __global realM* cgm, realM cpm[NWI][MWI/VWM],
+ #if SA == 1 && SB == 1
+ __local realM* alm, __local realN* blm
+ #elif SA == 1
+ __local realM* alm
+ #elif SB == 1
+ __local realN* blm
+ #endif
+ ) {
+
+ // Allocates workitem-private memory (registers)
+ realM apm[MWI/VWM];
+ realN bpm[NWI/VWN];
+
+ // Combined thread identifier (volatile to disable caching)
+ #if SA == 1 || SB == 1
+ volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
+ #endif
+
+ // Initializes the accumulation registers
+ InitAccRegisters(cpm);
// Loops over all workgroup tiles
for (int kwg=0; kwg<kSizeK; kwg+=KWG) {
@@ -515,8 +519,6 @@ __kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
#if SB == 1
GlobalToLocalB(bgm, blm, kSizeN, tid, kwg);
#endif
-
- // Synchronizes all threads in a workgroup
#if SA == 1 || SB == 1
barrier(CLK_LOCAL_MEM_FENCE);
#endif
@@ -552,19 +554,126 @@ __kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
MultiplyAccumulate(cpm, apm, bpm);
}
}
-
- // Synchronizes all threads in a workgroup
#if SA == 1 || SB == 1
barrier(CLK_LOCAL_MEM_FENCE);
#endif
}
+}
- // Stores an MWG * NWG tile of results and perform the multiplication with alpha and beta
+// =================================================================================================
+
+// Main entry point of the kernel. This is the regular full version.
+__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+__kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
+ const real alpha, const real beta,
+ const __global realM* restrict agm,
+ const __global realN* restrict bgm,
+ __global realM* cgm) {
+
+ // Allocates workgroup-private memory (local memory)
+ #if SA == 1
+ __local realM alm[KWG * MWG/VWM];
+ #endif
+ #if SB == 1
+ __local realN blm[KWG * NWG/VWN];
+ #endif
+
+ // Computes the matrix-multiplication and stores the result in register memory
+ realM cpm[NWI][MWI/VWM];
+ #if SA == 1 && SB == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ #elif SA == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ #elif SB == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ #else
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ #endif
+
+ // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
StoreResults(cgm, cpm, kSizeM, alpha, beta);
}
// =================================================================================================
+// Main entry point of the kernel. This is the upper-triangular version.
+__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+__kernel void XgemmUpper(const int kSizeN, const int kSizeK,
+ const real alpha, const real beta,
+ const __global realM* restrict agm,
+ const __global realN* restrict bgm,
+ __global realM* cgm) {
+
+ // Skips this workgroup if it does not contribute to the upper-triangle
+ if (get_group_id(1)*NWG < get_group_id(0)*MWG) {
+ return;
+ }
+
+ // Allocates workgroup-private memory (local memory)
+ #if SA == 1
+ __local realM alm[KWG * MWG/VWM];
+ #endif
+ #if SB == 1
+ __local realN blm[KWG * NWG/VWN];
+ #endif
+
+ // Computes the matrix-multiplication and stores the result in register memory
+ realM cpm[NWI][MWI/VWM];
+ #if SA == 1 && SB == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ #elif SA == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ #elif SB == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ #else
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ #endif
+
+ // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
+ StoreResults(cgm, cpm, kSizeN, alpha, beta);
+}
+
+// =================================================================================================
+
+// Main entry point of the kernel. This is the lower-triangular version.
+__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+__kernel void XgemmLower(const int kSizeN, const int kSizeK,
+ const real alpha, const real beta,
+ const __global realM* restrict agm,
+ const __global realN* restrict bgm,
+ __global realM* cgm) {
+
+ // Skips this workgroup if it does not contribute to the lower-triangle
+ if (get_group_id(1)*NWG > get_group_id(0)*MWG) {
+ return;
+ }
+
+ // Allocates workgroup-private memory (local memory)
+ #if SA == 1
+ __local realM alm[KWG * MWG/VWM];
+ #endif
+ #if SB == 1
+ __local realN blm[KWG * NWG/VWN];
+ #endif
+
+ // Computes the matrix-multiplication and stores the result in register memory
+ realM cpm[NWI][MWI/VWM];
+ #if SA == 1 && SB == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ #elif SA == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ #elif SB == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ #else
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ #endif
+
+ // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
+ StoreResults(cgm, cpm, kSizeN, alpha, beta);
+}
+
+// =================================================================================================
+
// End of the C++11 raw string literal
)";