summaryrefslogtreecommitdiff
path: root/src/kernels/level3/xgemm_part2.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/level3/xgemm_part2.opencl')
-rw-r--r--src/kernels/level3/xgemm_part2.opencl324
1 files changed, 81 insertions, 243 deletions
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
index 42c1127c..e8234a29 100644
--- a/src/kernels/level3/xgemm_part2.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -7,7 +7,7 @@
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
-// This is part 2 of 2 of the GEMM kernel. See part 1 for more information.
+// This is part 2 of 3 of the GEMM kernel. See part 1 for more information.
//
// =================================================================================================
@@ -133,260 +133,98 @@ inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int
#endif
int idm = mg + GetGroupID0() * (MWG/VWM);
int idn = ng + GetGroupID1() * NWG;
-
- // The final multiplication with alpha and the addition with beta*C
int index = idn*(kSizeM/VWM) + idm;
+
realM result;
realM xval = cpm[ni][mi];
- realM yval = cgm[index];
- #if VWM == 1
- AXPBY(result, alpha, xval, beta, yval);
- #elif VWM == 2
- AXPBY(result.x, alpha, xval.x, beta, yval.x);
- AXPBY(result.y, alpha, xval.y, beta, yval.y);
- #elif VWM == 4
- AXPBY(result.x, alpha, xval.x, beta, yval.x);
- AXPBY(result.y, alpha, xval.y, beta, yval.y);
- AXPBY(result.z, alpha, xval.z, beta, yval.z);
- AXPBY(result.w, alpha, xval.w, beta, yval.w);
- #elif VWM == 8
- AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
- AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
- AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
- AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
- AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
- AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
- AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
- AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
- #elif VWM == 16
- AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
- AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
- AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
- AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
- AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
- AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
- AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
- AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
- AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
- AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
- AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
- AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
- AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
- AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
- AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
- AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
- #endif
- cgm[index] = result;
- }
- }
-}
-
-// =================================================================================================
-
-// Main body of the matrix-multiplication algorithm. It calls the (inlined) functions above.
-inline void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
- const __global realM* restrict agm, const __global realN* restrict bgm,
- __global realM* cgm, realM cpm[NWI][MWI/VWM]
- #if SA == 1 && SB == 1
- , __local realM* alm, __local realN* blm
- #elif SA == 1
- , __local realM* alm
- #elif SB == 1
- , __local realN* blm
- #endif
- ) {
-
- // Allocates workitem-private memory (registers)
- realM apm[MWI/VWM];
- realN bpm[NWI/VWN];
-
- // Combined thread identifier (volatile to disable caching)
- #if SA == 1 || SB == 1
- volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
- #endif
-
- // Initializes the accumulation registers
- InitAccRegisters(cpm);
-
- // Loops over all workgroup tiles
- for (int kwg=0; kwg<kSizeK; kwg+=KWG) {
- // Loads data: off-chip --> local (matrix A)
- #if SA == 1
- GlobalToLocalA(agm, alm, kSizeM, tid, kwg);
- #endif
- // Loads data: off-chip --> local (matrix B)
- #if SB == 1
- GlobalToLocalB(bgm, blm, kSizeN, tid, kwg);
- #endif
- #if SA == 1 || SB == 1
- barrier(CLK_LOCAL_MEM_FENCE);
- #endif
-
- // Loops over all workitem tiles, unrolled by a factor KWI
- for (int pwi=0; pwi<KWG; pwi+=KWI) {
- #pragma unroll
- for (int pit=0; pit<KWI; ++pit) {
- #if SA == 0 || SB == 0
- int idk = kwg + pwi + pit;
- #endif
- #if SA == 1 || SB == 1
- int kg = pwi+pit;
- #endif
-
- // Loads data: local --> private (matrix A)
- #if SA == 1
- LocalToPrivateA(alm, apm, kg);
- // Loads data: off-chip --> private (matrix A)
- #else
- GlobalToPrivateA(agm, apm, kSizeM, idk, kwg);
+ // The final multiplication with alpha (in case beta == 0)
+ if (IsZero(beta)) {
+ #if VWM == 1
+ Multiply(result, alpha, xval);
+ #elif VWM == 2
+ Multiply(result.x, alpha, xval.x);
+ Multiply(result.y, alpha, xval.y);
+ #elif VWM == 4
+ Multiply(result.x, alpha, xval.x);
+ Multiply(result.y, alpha, xval.y);
+ Multiply(result.z, alpha, xval.z);
+ Multiply(result.w, alpha, xval.w);
+ #elif VWM == 8
+ Multiply(result.s0, alpha, xval.s0);
+ Multiply(result.s1, alpha, xval.s1);
+ Multiply(result.s2, alpha, xval.s2);
+ Multiply(result.s3, alpha, xval.s3);
+ Multiply(result.s4, alpha, xval.s4);
+ Multiply(result.s5, alpha, xval.s5);
+ Multiply(result.s6, alpha, xval.s6);
+ Multiply(result.s7, alpha, xval.s7);
+ #elif VWM == 16
+ Multiply(result.s0, alpha, xval.s0);
+ Multiply(result.s1, alpha, xval.s1);
+ Multiply(result.s2, alpha, xval.s2);
+ Multiply(result.s3, alpha, xval.s3);
+ Multiply(result.s4, alpha, xval.s4);
+ Multiply(result.s5, alpha, xval.s5);
+ Multiply(result.s6, alpha, xval.s6);
+ Multiply(result.s7, alpha, xval.s7);
+ Multiply(result.s8, alpha, xval.s8);
+ Multiply(result.s9, alpha, xval.s9);
+ Multiply(result.sA, alpha, xval.sA);
+ Multiply(result.sB, alpha, xval.sB);
+ Multiply(result.sC, alpha, xval.sC);
+ Multiply(result.sD, alpha, xval.sD);
+ Multiply(result.sE, alpha, xval.sE);
+ Multiply(result.sF, alpha, xval.sF);
#endif
+ }
- // Loads data: local --> private (matrix B)
- #if SB == 1
- LocalToPrivateB(blm, bpm, kg);
- // Loads data: off-chip --> private (matrix B)
- #else
- GlobalToPrivateB(bgm, bpm, kSizeN, idk);
+ // The final multiplication with alpha and the addition with beta*C
+ else {
+ realM yval = cgm[index];
+ #if VWM == 1
+ AXPBY(result, alpha, xval, beta, yval);
+ #elif VWM == 2
+ AXPBY(result.x, alpha, xval.x, beta, yval.x);
+ AXPBY(result.y, alpha, xval.y, beta, yval.y);
+ #elif VWM == 4
+ AXPBY(result.x, alpha, xval.x, beta, yval.x);
+ AXPBY(result.y, alpha, xval.y, beta, yval.y);
+ AXPBY(result.z, alpha, xval.z, beta, yval.z);
+ AXPBY(result.w, alpha, xval.w, beta, yval.w);
+ #elif VWM == 8
+ AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
+ AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
+ AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
+ AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
+ AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
+ AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
+ AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
+ AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
+ #elif VWM == 16
+ AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
+ AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
+ AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
+ AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
+ AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
+ AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
+ AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
+ AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
+ AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
+ AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
+ AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
+ AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
+ AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
+ AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
+ AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
+ AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
#endif
-
- // Performs the accumulation (Cpm += Apm * Bpm)
- MultiplyAccumulate(cpm, apm, bpm);
}
+ cgm[index] = result;
}
- #if SA == 1 || SB == 1
- barrier(CLK_LOCAL_MEM_FENCE);
- #endif
- }
- #if GLOBAL_MEM_FENCE == 1
- barrier(CLK_GLOBAL_MEM_FENCE);
- #endif
-}
-
-// =================================================================================================
-// The upper-triangular and lower-triangular kernels are only used in special cases
-#if defined(ROUTINE_SYRK) || defined(ROUTINE_HERK) || defined(ROUTINE_SYR2K) || defined(ROUTINE_HER2K)
-
-// Main entry point of the kernel. This is the upper-triangular version.
-__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
-__kernel void XgemmUpper(const int kSizeN, const int kSizeK,
- const __constant real* restrict arg_alpha,
- const __constant real* restrict arg_beta,
- const __global realM* restrict agm,
- const __global realN* restrict bgm,
- __global realM* cgm) {
- const real alpha = arg_alpha[0];
- const real beta = arg_beta[0];
-
- // Skip these threads if they do not contain threads contributing to the upper-triangle
- if (GetGroupID1()*NWG < GetGroupID0()*MWG) {
- return;
- }
-
- // Allocates workgroup-private memory (local memory)
- #if SA == 1
- __local realM alm[KWG * MWG/VWM];
- #endif
- #if SB == 1
- __local realN blm[KWG * NWG/VWN];
- #endif
-
- // Computes the matrix-multiplication and stores the result in register memory
- realM cpm[NWI][MWI/VWM];
- #if SA == 1 && SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
- #elif SA == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
- #elif SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
- #else
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
- #endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeN, alpha, beta);
-}
-
-// Main entry point of the kernel. This is the lower-triangular version.
-__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
-__kernel void XgemmLower(const int kSizeN, const int kSizeK,
- const __constant real* restrict arg_alpha,
- const __constant real* restrict arg_beta,
- const __global realM* restrict agm,
- const __global realN* restrict bgm,
- __global realM* cgm) {
- const real alpha = arg_alpha[0];
- const real beta = arg_beta[0];
-
- // Skip these threads if they do not contain threads contributing to the lower-triangle
- if (GetGroupID1()*NWG > GetGroupID0()*MWG) {
- return;
}
-
- // Allocates workgroup-private memory (local memory)
- #if SA == 1
- __local realM alm[KWG * MWG/VWM];
- #endif
- #if SB == 1
- __local realN blm[KWG * NWG/VWN];
- #endif
-
- // Computes the matrix-multiplication and stores the result in register memory
- realM cpm[NWI][MWI/VWM];
- #if SA == 1 && SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
- #elif SA == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
- #elif SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
- #else
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
- #endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeN, alpha, beta);
-}
-
-// =================================================================================================
-// If not using a triangular version, include the regular kernel
-#else
-
-// Main entry point of the kernel. This is the regular full version.
-__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
-__kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
- const __constant real* restrict arg_alpha,
- const __constant real* restrict arg_beta,
- const __global realM* restrict agm,
- const __global realN* restrict bgm,
- __global realM* cgm) {
- const real alpha = arg_alpha[0];
- const real beta = arg_beta[0];
-
- // Allocates workgroup-private memory (local memory)
- #if SA == 1
- __local realM alm[KWG * MWG/VWM];
- #endif
- #if SB == 1
- __local realN blm[KWG * NWG/VWN];
- #endif
-
- // Computes the matrix-multiplication and stores the result in register memory
- realM cpm[NWI][MWI/VWM];
- #if SA == 1 && SB == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
- #elif SA == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
- #elif SB == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
- #else
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm);
- #endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeM, alpha, beta);
}
-#endif
// =================================================================================================
// End of the C++11 raw string literal