diff options
author | CNugteren <web@cedricnugteren.nl> | 2015-06-23 17:58:51 +0200 |
---|---|---|
committer | CNugteren <web@cedricnugteren.nl> | 2015-06-23 17:58:51 +0200 |
commit | 9fc38cdf5ed44ef41cf3d6cf9e7c32585447c042 (patch) | |
tree | 6f3163b61d780f6b25e21cc34dda68e76a2de637 /src/kernels | |
parent | 0a3831e6d1eb437a9ef9ac7570f9a554b2c35edb (diff) |
Added a lower/upper triangular version of the GEMM kernel
Diffstat (limited to 'src/kernels')
-rw-r--r-- | src/kernels/xgemm.opencl | 371 |
1 file changed, 240 insertions, 131 deletions
diff --git a/src/kernels/xgemm.opencl b/src/kernels/xgemm.opencl index a4f45e90..4c7ae064 100644 --- a/src/kernels/xgemm.opencl +++ b/src/kernels/xgemm.opencl @@ -127,6 +127,55 @@ R"( // ================================================================================================= +// Initializes the accumulation registers to zero +inline void InitAccRegisters(realM cpm[NWI][MWI/VWM]) { + #pragma unroll + for (int mi=0; mi<MWI/VWM; ++mi) { + #pragma unroll + for (int ni=0; ni<NWI; ++ni) { + #if VWM == 1 + SetToZero(cpm[ni][mi]); + #elif VWM == 2 + SetToZero(cpm[ni][mi].x); + SetToZero(cpm[ni][mi].y); + #elif VWM == 4 + SetToZero(cpm[ni][mi].x); + SetToZero(cpm[ni][mi].y); + SetToZero(cpm[ni][mi].z); + SetToZero(cpm[ni][mi].w); + #elif VWM == 8 + SetToZero(cpm[ni][mi].s0); + SetToZero(cpm[ni][mi].s1); + SetToZero(cpm[ni][mi].s2); + SetToZero(cpm[ni][mi].s3); + SetToZero(cpm[ni][mi].s4); + SetToZero(cpm[ni][mi].s5); + SetToZero(cpm[ni][mi].s6); + SetToZero(cpm[ni][mi].s7); + #elif VWM == 16 + SetToZero(cpm[ni][mi].s0); + SetToZero(cpm[ni][mi].s1); + SetToZero(cpm[ni][mi].s2); + SetToZero(cpm[ni][mi].s3); + SetToZero(cpm[ni][mi].s4); + SetToZero(cpm[ni][mi].s5); + SetToZero(cpm[ni][mi].s6); + SetToZero(cpm[ni][mi].s7); + SetToZero(cpm[ni][mi].s8); + SetToZero(cpm[ni][mi].s9); + SetToZero(cpm[ni][mi].sA); + SetToZero(cpm[ni][mi].sB); + SetToZero(cpm[ni][mi].sC); + SetToZero(cpm[ni][mi].sD); + SetToZero(cpm[ni][mi].sE); + SetToZero(cpm[ni][mi].sF); + #endif + } + } +} + +// ================================================================================================= + // Caches global off-chip memory into local (shared) memory on-chip. This function is specific for // caching the A input matrix. 
#if SA == 1 @@ -272,71 +321,6 @@ inline void LocalToPrivateB(__local realN* blm, realN bpm[NWI/VWN], const int kg // ================================================================================================= -// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication -// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm -inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM, - const real alpha, const real beta) { - #pragma unroll - for (int ni=0; ni<NWI; ++ni) { - #pragma unroll - for (int mi=0; mi<MWI/VWM; ++mi) { - #if STRM == 0 - int mg = mi + get_local_id(0)*(MWI/VWM); - #elif STRM == 1 - int mg = get_local_id(0) + mi*MDIMC; - #endif - #if STRN == 0 - int ng = ni + get_local_id(1)*NWI; - #elif STRN == 1 - int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC; - #endif - int idm = mg + get_group_id(0)*(MWG/VWM); - int idn = ng + get_group_id(1)*NWG; - int index = idn*(kSizeM/VWM) + idm; - realM cval = cgm[index]; - #if VWM == 1 - AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval); - #elif VWM == 2 - AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x); - AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y); - #elif VWM == 4 - AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x); - AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y); - AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z); - AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w); - #elif VWM == 8 - AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0); - AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1); - AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2); - AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3); - AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4); - AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5); - AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6); - AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, 
cval.s7); - #elif VWM == 16 - AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0); - AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1); - AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2); - AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3); - AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4); - AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5); - AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6); - AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7); - AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8); - AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9); - AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA); - AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB); - AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC); - AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD); - AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE); - AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF); - #endif - } - } -} - -// ================================================================================================= - // The vectorised multiply-add function inline realM MultiplyAddVector(realM cvec, const realM avec, const real bval) { #if USE_VECTOR_MAD == 1 @@ -432,77 +416,97 @@ inline void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], real // ================================================================================================= -// Main entry of the kernel. 
This function contains the basic skeleton, the functionality is -// provided by the inlined functions above -__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1))) -__kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, - const real alpha, const real beta, - const __global realM* restrict agm, - const __global realN* restrict bgm, - __global realM* cgm) { - - // Combined thread identifier - #if SA == 1 || SB == 1 - volatile int tid = get_local_id(0) + MDIMC*get_local_id(1); - #endif - - // Allocates workgroup-private memory (local memory) - #if SA == 1 - __local realM alm[KWG * MWG/VWM]; - #endif - #if SB == 1 - __local realN blm[KWG * NWG/VWN]; - #endif - - // Allocates workitem-private memory (registers) - realM apm[MWI/VWM]; - realN bpm[NWI/VWN]; - realM cpm[NWI][MWI/VWM]; - - // Initializes the accumulation registers +// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication +// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm +inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM, + const real alpha, const real beta) { #pragma unroll - for (int mi=0; mi<MWI/VWM; ++mi) { + for (int ni=0; ni<NWI; ++ni) { #pragma unroll - for (int ni=0; ni<NWI; ++ni) { + for (int mi=0; mi<MWI/VWM; ++mi) { + #if STRM == 0 + int mg = mi + get_local_id(0)*(MWI/VWM); + #elif STRM == 1 + int mg = get_local_id(0) + mi*MDIMC; + #endif + #if STRN == 0 + int ng = ni + get_local_id(1)*NWI; + #elif STRN == 1 + int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC; + #endif + int idm = mg + get_group_id(0)*(MWG/VWM); + int idn = ng + get_group_id(1)*NWG; + + // The final multiplication with alpha and the addition with beta*C + int index = idn*(kSizeM/VWM) + idm; + realM cval = cgm[index]; #if VWM == 1 - SetToZero(cpm[ni][mi]); + AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval); #elif VWM == 2 - SetToZero(cpm[ni][mi].x); - SetToZero(cpm[ni][mi].y); + AXPBY(cgm[index].x, 
alpha, cpm[ni][mi].x, beta, cval.x); + AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y); #elif VWM == 4 - SetToZero(cpm[ni][mi].x); - SetToZero(cpm[ni][mi].y); - SetToZero(cpm[ni][mi].z); - SetToZero(cpm[ni][mi].w); + AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x); + AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y); + AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z); + AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w); #elif VWM == 8 - SetToZero(cpm[ni][mi].s0); - SetToZero(cpm[ni][mi].s1); - SetToZero(cpm[ni][mi].s2); - SetToZero(cpm[ni][mi].s3); - SetToZero(cpm[ni][mi].s4); - SetToZero(cpm[ni][mi].s5); - SetToZero(cpm[ni][mi].s6); - SetToZero(cpm[ni][mi].s7); + AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0); + AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1); + AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2); + AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3); + AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4); + AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5); + AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6); + AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7); #elif VWM == 16 - SetToZero(cpm[ni][mi].s0); - SetToZero(cpm[ni][mi].s1); - SetToZero(cpm[ni][mi].s2); - SetToZero(cpm[ni][mi].s3); - SetToZero(cpm[ni][mi].s4); - SetToZero(cpm[ni][mi].s5); - SetToZero(cpm[ni][mi].s6); - SetToZero(cpm[ni][mi].s7); - SetToZero(cpm[ni][mi].s8); - SetToZero(cpm[ni][mi].s9); - SetToZero(cpm[ni][mi].sA); - SetToZero(cpm[ni][mi].sB); - SetToZero(cpm[ni][mi].sC); - SetToZero(cpm[ni][mi].sD); - SetToZero(cpm[ni][mi].sE); - SetToZero(cpm[ni][mi].sF); + AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0); + AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1); + AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2); + AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3); + AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4); + 
AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5); + AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6); + AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7); + AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8); + AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9); + AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA); + AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB); + AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC); + AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD); + AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE); + AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF); #endif } } +} + +// ================================================================================================= + +// Main body of the matrix-multiplication algorithm. It calls the (inlined) functions above. +inline void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, + const __global realM* restrict agm, const __global realN* restrict bgm, + __global realM* cgm, realM cpm[NWI][MWI/VWM], + #if SA == 1 && SB == 1 + __local realM* alm, __local realN* blm + #elif SA == 1 + __local realM* alm + #elif SB == 1 + __local realN* blm + #endif + ) { + + // Allocates workitem-private memory (registers) + realM apm[MWI/VWM]; + realN bpm[NWI/VWN]; + + // Combined thread identifier (volatile to disable caching) + #if SA == 1 || SB == 1 + volatile int tid = get_local_id(0) + MDIMC*get_local_id(1); + #endif + + // Initializes the accumulation registers + InitAccRegisters(cpm); // Loops over all workgroup tiles for (int kwg=0; kwg<kSizeK; kwg+=KWG) { @@ -515,8 +519,6 @@ __kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, #if SB == 1 GlobalToLocalB(bgm, blm, kSizeN, tid, kwg); #endif - - // Synchronizes all threads in a workgroup #if SA == 1 || SB == 1 barrier(CLK_LOCAL_MEM_FENCE); #endif @@ -552,19 +554,126 @@ __kernel void Xgemm(const int kSizeM, const int kSizeN, const 
int kSizeK, MultiplyAccumulate(cpm, apm, bpm); } } - - // Synchronizes all threads in a workgroup #if SA == 1 || SB == 1 barrier(CLK_LOCAL_MEM_FENCE); #endif } +} - // Stores an MWG * NWG tile of results and perform the multiplication with alpha and beta +// ================================================================================================= + +// Main entry point of the kernel. This is the regular full version. +__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1))) +__kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, + const real alpha, const real beta, + const __global realM* restrict agm, + const __global realN* restrict bgm, + __global realM* cgm) { + + // Allocates workgroup-private memory (local memory) + #if SA == 1 + __local realM alm[KWG * MWG/VWM]; + #endif + #if SB == 1 + __local realN blm[KWG * NWG/VWN]; + #endif + + // Computes the matrix-multiplication and stores the result in register memory + realM cpm[NWI][MWI/VWM]; + #if SA == 1 && SB == 1 + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + #elif SA == 1 + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + #elif SB == 1 + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm); + #else + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm); + #endif + + // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta StoreResults(cgm, cpm, kSizeM, alpha, beta); } // ================================================================================================= +// Main entry point of the kernel. This is the upper-triangular version. 
+__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1))) +__kernel void XgemmUpper(const int kSizeN, const int kSizeK, + const real alpha, const real beta, + const __global realM* restrict agm, + const __global realN* restrict bgm, + __global realM* cgm) { + + // Skip these threads if they do not contain threads contributing to the upper-triangle + if (get_group_id(1)*NWG < get_group_id(0)*MWG) { + return; + } + + // Allocates workgroup-private memory (local memory) + #if SA == 1 + __local realM alm[KWG * MWG/VWM]; + #endif + #if SB == 1 + __local realN blm[KWG * NWG/VWN]; + #endif + + // Computes the matrix-multiplication and stores the result in register memory + realM cpm[NWI][MWI/VWM]; + #if SA == 1 && SB == 1 + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + #elif SA == 1 + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + #elif SB == 1 + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm); + #else + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm); + #endif + + // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta + StoreResults(cgm, cpm, kSizeN, alpha, beta); +} + +// ================================================================================================= + +// Main entry point of the kernel. This is the lower-triangular version. 
+__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1))) +__kernel void XgemmLower(const int kSizeN, const int kSizeK, + const real alpha, const real beta, + const __global realM* restrict agm, + const __global realN* restrict bgm, + __global realM* cgm) { + + // Skip these threads if they do not contain threads contributing to the lower-triangle + if (get_group_id(1)*NWG > get_group_id(0)*MWG) { + return; + } + + // Allocates workgroup-private memory (local memory) + #if SA == 1 + __local realM alm[KWG * MWG/VWM]; + #endif + #if SB == 1 + __local realN blm[KWG * NWG/VWN]; + #endif + + // Computes the matrix-multiplication and stores the result in register memory + realM cpm[NWI][MWI/VWM]; + #if SA == 1 && SB == 1 + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + #elif SA == 1 + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + #elif SB == 1 + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm); + #else + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm); + #endif + + // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta + StoreResults(cgm, cpm, kSizeN, alpha, beta); +} + +// ================================================================================================= + // End of the C++11 raw string literal )"; |