author     Cedric Nugteren <web@cedricnugteren.nl>  2016-02-08 20:06:02 +0100
committer  Cedric Nugteren <web@cedricnugteren.nl>  2016-02-08 20:06:02 +0100
commit     bf84463ab20f2f39071719fad9bd28a6bb13fc24 (patch)
tree       df4a6fff31d178186bd36538da3705ccf3353eb3 /src/kernels/level3/xgemm_part2.opencl
parent     38c56bbde2ed108d47bd058ba239725b3396475d (diff)
Separated the GEMM kernel in two parts to reduce string length for MSVC
Diffstat (limited to 'src/kernels/level3/xgemm_part2.opencl')
-rw-r--r--  src/kernels/level3/xgemm_part2.opencl  379
1 file changed, 379 insertions(+), 0 deletions(-)
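Note: the split exists because MSVC limits how long a single string literal may be, and the kernel
source is embedded into the host code as C++11 raw string literals (see the R"( ... )" markers in
the new file below). As a rough illustration of how the two parts could be stitched back together
on the host side (a minimal sketch with hypothetical paths, not CLBlast's actual loader code):

    // Illustrative host-side loader (an assumption, not taken from the CLBlast sources).
    // Each .opencl part starts with R"( and ends with )", so #include-ing it in the middle of an
    // expression yields one raw string literal per part; adjacent literals are concatenated by the
    // compiler, while each individual literal stays below MSVC's per-literal length limit.
    #include <string>

    const std::string xgemm_source =
      #include "xgemm_part1.opencl"
      #include "xgemm_part2.opencl"
    ;  // later passed to clCreateProgramWithSource() / clBuildProgram()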
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
new file mode 100644
index 00000000..c0760db6
--- /dev/null
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -0,0 +1,379 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This is part 2 of 2 of the GEMM kernel. See part 1 for more information.
+//
+// =================================================================================================
+
+// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
+// literal). Comment-out this line for syntax-highlighting when developing.
+R"(
+
+// =================================================================================================
+
+// The vectorised multiply-add function
+inline realM MultiplyAddVector(realM cvec, const realM avec, const real bval) {
+ #if USE_VECTOR_MAD == 1
+ cvec += avec * bval;
+ #else
+ #if VWM == 1
+ MultiplyAdd(cvec, avec, bval);
+ #elif VWM == 2
+ MultiplyAdd(cvec.x , avec.x, bval);
+ MultiplyAdd(cvec.y , avec.y, bval);
+ #elif VWM == 4
+ MultiplyAdd(cvec.x , avec.x, bval);
+ MultiplyAdd(cvec.y , avec.y, bval);
+ MultiplyAdd(cvec.z , avec.z, bval);
+ MultiplyAdd(cvec.w , avec.w, bval);
+ #elif VWM == 8
+ MultiplyAdd(cvec.s0, avec.s0, bval);
+ MultiplyAdd(cvec.s1, avec.s1, bval);
+ MultiplyAdd(cvec.s2, avec.s2, bval);
+ MultiplyAdd(cvec.s3, avec.s3, bval);
+ MultiplyAdd(cvec.s4, avec.s4, bval);
+ MultiplyAdd(cvec.s5, avec.s5, bval);
+ MultiplyAdd(cvec.s6, avec.s6, bval);
+ MultiplyAdd(cvec.s7, avec.s7, bval);
+ #elif VWM == 16
+ MultiplyAdd(cvec.s0, avec.s0, bval);
+ MultiplyAdd(cvec.s1, avec.s1, bval);
+ MultiplyAdd(cvec.s2, avec.s2, bval);
+ MultiplyAdd(cvec.s3, avec.s3, bval);
+ MultiplyAdd(cvec.s4, avec.s4, bval);
+ MultiplyAdd(cvec.s5, avec.s5, bval);
+ MultiplyAdd(cvec.s6, avec.s6, bval);
+ MultiplyAdd(cvec.s7, avec.s7, bval);
+ MultiplyAdd(cvec.s8, avec.s8, bval);
+ MultiplyAdd(cvec.s9, avec.s9, bval);
+ MultiplyAdd(cvec.sA, avec.sA, bval);
+ MultiplyAdd(cvec.sB, avec.sB, bval);
+ MultiplyAdd(cvec.sC, avec.sC, bval);
+ MultiplyAdd(cvec.sD, avec.sD, bval);
+ MultiplyAdd(cvec.sE, avec.sE, bval);
+ MultiplyAdd(cvec.sF, avec.sF, bval);
+ #endif
+ #endif
+ return cvec;
+}
+
+// Performs the actual computation: Cpm += Apm * Bpm
+inline void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) {
+ #pragma unroll
+ for (int ni=0; ni<NWI/VWN; ++ni) {
+ #pragma unroll
+ for (int mi=0; mi<MWI/VWM; ++mi) {
+ #if VWN == 1
+ cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni]);
+ #elif VWN == 2
+ cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x);
+ cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y);
+ #elif VWN == 4
+ cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].x);
+ cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].y);
+ cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].z);
+ cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].w);
+ #elif VWN == 8
+ cpm[ni*VWN + 0][mi] = MultiplyAddVector(cpm[ni*VWN + 0][mi], apm[mi], bpm[ni].s0);
+ cpm[ni*VWN + 1][mi] = MultiplyAddVector(cpm[ni*VWN + 1][mi], apm[mi], bpm[ni].s1);
+ cpm[ni*VWN + 2][mi] = MultiplyAddVector(cpm[ni*VWN + 2][mi], apm[mi], bpm[ni].s2);
+ cpm[ni*VWN + 3][mi] = MultiplyAddVector(cpm[ni*VWN + 3][mi], apm[mi], bpm[ni].s3);
+ cpm[ni*VWN + 4][mi] = MultiplyAddVector(cpm[ni*VWN + 4][mi], apm[mi], bpm[ni].s4);
+ cpm[ni*VWN + 5][mi] = MultiplyAddVector(cpm[ni*VWN + 5][mi], apm[mi], bpm[ni].s5);
+ cpm[ni*VWN + 6][mi] = MultiplyAddVector(cpm[ni*VWN + 6][mi], apm[mi], bpm[ni].s6);
+ cpm[ni*VWN + 7][mi] = MultiplyAddVector(cpm[ni*VWN + 7][mi], apm[mi], bpm[ni].s7);
+ #elif VWN == 16
+ cpm[ni*VWN + 0 ][mi] = MultiplyAddVector(cpm[ni*VWN + 0 ][mi], apm[mi], bpm[ni].s0);
+ cpm[ni*VWN + 1 ][mi] = MultiplyAddVector(cpm[ni*VWN + 1 ][mi], apm[mi], bpm[ni].s1);
+ cpm[ni*VWN + 2 ][mi] = MultiplyAddVector(cpm[ni*VWN + 2 ][mi], apm[mi], bpm[ni].s2);
+ cpm[ni*VWN + 3 ][mi] = MultiplyAddVector(cpm[ni*VWN + 3 ][mi], apm[mi], bpm[ni].s3);
+ cpm[ni*VWN + 4 ][mi] = MultiplyAddVector(cpm[ni*VWN + 4 ][mi], apm[mi], bpm[ni].s4);
+ cpm[ni*VWN + 5 ][mi] = MultiplyAddVector(cpm[ni*VWN + 5 ][mi], apm[mi], bpm[ni].s5);
+ cpm[ni*VWN + 6 ][mi] = MultiplyAddVector(cpm[ni*VWN + 6 ][mi], apm[mi], bpm[ni].s6);
+ cpm[ni*VWN + 7 ][mi] = MultiplyAddVector(cpm[ni*VWN + 7 ][mi], apm[mi], bpm[ni].s7);
+ cpm[ni*VWN + 8 ][mi] = MultiplyAddVector(cpm[ni*VWN + 8 ][mi], apm[mi], bpm[ni].s8);
+ cpm[ni*VWN + 9 ][mi] = MultiplyAddVector(cpm[ni*VWN + 9 ][mi], apm[mi], bpm[ni].s9);
+ cpm[ni*VWN + 10][mi] = MultiplyAddVector(cpm[ni*VWN + 10][mi], apm[mi], bpm[ni].sA);
+ cpm[ni*VWN + 11][mi] = MultiplyAddVector(cpm[ni*VWN + 11][mi], apm[mi], bpm[ni].sB);
+ cpm[ni*VWN + 12][mi] = MultiplyAddVector(cpm[ni*VWN + 12][mi], apm[mi], bpm[ni].sC);
+ cpm[ni*VWN + 13][mi] = MultiplyAddVector(cpm[ni*VWN + 13][mi], apm[mi], bpm[ni].sD);
+ cpm[ni*VWN + 14][mi] = MultiplyAddVector(cpm[ni*VWN + 14][mi], apm[mi], bpm[ni].sE);
+ cpm[ni*VWN + 15][mi] = MultiplyAddVector(cpm[ni*VWN + 15][mi], apm[mi], bpm[ni].sF);
+ #endif
+ }
+ }
+}
+
+// =================================================================================================
+
+// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
+// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
+inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM,
+ const real alpha, const real beta) {
+ #pragma unroll
+ for (int ni=0; ni<NWI; ++ni) {
+ #pragma unroll
+ for (int mi=0; mi<MWI/VWM; ++mi) {
+ #if STRM == 0
+ int mg = mi + get_local_id(0)*(MWI/VWM);
+ #elif STRM == 1
+ int mg = get_local_id(0) + mi*MDIMC;
+ #endif
+ #if STRN == 0
+ int ng = ni + get_local_id(1)*NWI;
+ #elif STRN == 1
+ int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC;
+ #endif
+ int idm = mg + get_group_id(0)*(MWG/VWM);
+ int idn = ng + get_group_id(1)*NWG;
+
+ // The final multiplication with alpha and the addition with beta*C
+ int index = idn*(kSizeM/VWM) + idm;
+ realM cval = cgm[index];
+ #if VWM == 1
+ AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval);
+ #elif VWM == 2
+ AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
+ AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
+ #elif VWM == 4
+ AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
+ AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
+ AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z);
+ AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w);
+ #elif VWM == 8
+ AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
+ AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
+ AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
+ AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
+ AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
+ AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
+ AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
+ AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
+ #elif VWM == 16
+ AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
+ AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
+ AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
+ AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
+ AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
+ AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
+ AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
+ AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
+ AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8);
+ AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9);
+ AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA);
+ AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB);
+ AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC);
+ AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD);
+ AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE);
+ AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF);
+ #endif
+ }
+ }
+}
+
+// =================================================================================================
+
+// Main body of the matrix-multiplication algorithm. It calls the (inlined) functions above.
+inline void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
+ const __global realM* restrict agm, const __global realN* restrict bgm,
+ __global realM* cgm, realM cpm[NWI][MWI/VWM]
+ #if SA == 1 && SB == 1
+ , __local realM* alm, __local realN* blm
+ #elif SA == 1
+ , __local realM* alm
+ #elif SB == 1
+ , __local realN* blm
+ #endif
+ ) {
+
+ // Allocates workitem-private memory (registers)
+ realM apm[MWI/VWM];
+ realN bpm[NWI/VWN];
+
+ // Combined thread identifier (volatile to disable caching)
+ #if SA == 1 || SB == 1
+ volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
+ #endif
+
+ // Initializes the accumulation registers
+ InitAccRegisters(cpm);
+
+ // Loops over all workgroup tiles
+ for (int kwg=0; kwg<kSizeK; kwg+=KWG) {
+
+ // Loads data: off-chip --> local (matrix A)
+ #if SA == 1
+ GlobalToLocalA(agm, alm, kSizeM, tid, kwg);
+ #endif
+ // Loads data: off-chip --> local (matrix B)
+ #if SB == 1
+ GlobalToLocalB(bgm, blm, kSizeN, tid, kwg);
+ #endif
+ #if SA == 1 || SB == 1
+ barrier(CLK_LOCAL_MEM_FENCE);
+ #endif
+
+ // Loops over all workitem tiles, unrolled by a factor KWI
+ for (int pwi=0; pwi<KWG; pwi+=KWI) {
+ #pragma unroll
+ for (int pit=0; pit<KWI; ++pit) {
+ #if SA == 0 || SB == 0
+ int idk = kwg + pwi + pit;
+ #endif
+ #if SA == 1 || SB == 1
+ int kg = pwi+pit;
+ #endif
+
+ // Loads data: local --> private (matrix A)
+ #if SA == 1
+ LocalToPrivateA(alm, apm, kg);
+ // Loads data: off-chip --> private (matrix A)
+ #else
+ GlobalToPrivateA(agm, apm, kSizeM, idk, kwg);
+ #endif
+
+ // Loads data: local --> private (matrix B)
+ #if SB == 1
+ LocalToPrivateB(blm, bpm, kg);
+ // Loads data: off-chip --> private (matrix B)
+ #else
+ GlobalToPrivateB(bgm, bpm, kSizeN, idk);
+ #endif
+
+ // Performs the accumulation (Cpm += Apm * Bpm)
+ MultiplyAccumulate(cpm, apm, bpm);
+ }
+ }
+ #if SA == 1 || SB == 1
+ barrier(CLK_LOCAL_MEM_FENCE);
+ #endif
+ }
+}
+
+// =================================================================================================
+// The upper-triangular and lower-triangular kernels are only used in special cases
+#if defined(ROUTINE_SYRK) || defined(ROUTINE_HERK) || defined(ROUTINE_SYR2K) || defined(ROUTINE_HER2K)
+
+// Main entry point of the kernel. This is the upper-triangular version.
+__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+__kernel void XgemmUpper(const int kSizeN, const int kSizeK,
+ const real alpha, const real beta,
+ const __global realM* restrict agm,
+ const __global realN* restrict bgm,
+ __global realM* cgm) {
+
+  // Skips this workgroup if it contains no work contributing to the upper-triangle
+ if (get_group_id(1)*NWG < get_group_id(0)*MWG) {
+ return;
+ }
+
+ // Allocates workgroup-private memory (local memory)
+ #if SA == 1
+ __local realM alm[KWG * MWG/VWM];
+ #endif
+ #if SB == 1
+ __local realN blm[KWG * NWG/VWN];
+ #endif
+
+ // Computes the matrix-multiplication and stores the result in register memory
+ realM cpm[NWI][MWI/VWM];
+ #if SA == 1 && SB == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ #elif SA == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ #elif SB == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ #else
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ #endif
+
+ // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
+ StoreResults(cgm, cpm, kSizeN, alpha, beta);
+}
+
+// Main entry point of the kernel. This is the lower-triangular version.
+__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+__kernel void XgemmLower(const int kSizeN, const int kSizeK,
+ const real alpha, const real beta,
+ const __global realM* restrict agm,
+ const __global realN* restrict bgm,
+ __global realM* cgm) {
+
+  // Skips this workgroup if it contains no work contributing to the lower-triangle
+ if (get_group_id(1)*NWG > get_group_id(0)*MWG) {
+ return;
+ }
+
+ // Allocates workgroup-private memory (local memory)
+ #if SA == 1
+ __local realM alm[KWG * MWG/VWM];
+ #endif
+ #if SB == 1
+ __local realN blm[KWG * NWG/VWN];
+ #endif
+
+ // Computes the matrix-multiplication and stores the result in register memory
+ realM cpm[NWI][MWI/VWM];
+ #if SA == 1 && SB == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ #elif SA == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ #elif SB == 1
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ #else
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ #endif
+
+ // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
+ StoreResults(cgm, cpm, kSizeN, alpha, beta);
+}
+
+// =================================================================================================
+// If not using a triangular version, include the regular kernel
+#else
+
+// Main entry point of the kernel. This is the regular full version.
+__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+__kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
+ const real alpha, const real beta,
+ const __global realM* restrict agm,
+ const __global realN* restrict bgm,
+ __global realM* cgm) {
+
+ // Allocates workgroup-private memory (local memory)
+ #if SA == 1
+ __local realM alm[KWG * MWG/VWM];
+ #endif
+ #if SB == 1
+ __local realN blm[KWG * NWG/VWN];
+ #endif
+
+ // Computes the matrix-multiplication and stores the result in register memory
+ realM cpm[NWI][MWI/VWM];
+ #if SA == 1 && SB == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ #elif SA == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ #elif SB == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ #else
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ #endif
+
+ // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
+ StoreResults(cgm, cpm, kSizeM, alpha, beta);
+}
+
+#endif
+// =================================================================================================
+
+// End of the C++11 raw string literal
+)"
+
+// =================================================================================================
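For reference, the computation that the kernels above ultimately perform is Cgm = alpha*A*B + beta*Cgm,
with C stored column-major (the store index idn*(kSizeM/VWM) + idm in StoreResults). A naive scalar
equivalent is sketched below as a mental model or correctness reference; the column-major layouts of
A and B are assumptions here, and none of this is CLBlast code:

    // Naive column-major GEMM: C = alpha*A*B + beta*C (illustrative reference only).
    // Assumed layouts: A is M x K, B is K x N, C is M x N, all column-major.
    #include <vector>

    void ReferenceGemm(int M, int N, int K, float alpha, float beta,
                       const std::vector<float>& A,  // size M*K, A(m,k) = A[m + k*M]
                       const std::vector<float>& B,  // size K*N, B(k,n) = B[k + n*K]
                       std::vector<float>& C) {      // size M*N, C(m,n) = C[m + n*M]
      for (int n = 0; n < N; ++n) {
        for (int m = 0; m < M; ++m) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k) {
            acc += A[m + k*M] * B[k + n*K];
          }
          C[m + n*M] = alpha*acc + beta*C[m + n*M];
        }
      }
    }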