diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-12-09 14:09:13 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-12-09 14:09:13 +0100 |
commit | 23e3a85f2c328d4a23db2fca5d1d89d78163711f (patch) | |
tree | 02b8dd5364d958184c45c9bfdb2c28e38d72b24e /src/kernels/level3/xgemm_part3.opencl | |
parent | d9df62b7942bb8af5fd385b8545aceb1d8b578f3 (diff) |
Reformatted GEMM kernel to support array-to-register promotion
Diffstat (limited to 'src/kernels/level3/xgemm_part3.opencl')
-rw-r--r-- | src/kernels/level3/xgemm_part3.opencl | 143 |
1 file changed, 95 insertions(+), 48 deletions(-)
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl index f12fb304..157c1326 100644 --- a/src/kernels/level3/xgemm_part3.opencl +++ b/src/kernels/level3/xgemm_part3.opencl @@ -20,7 +20,7 @@ R"( // Main body of the matrix-multiplication algorithm. It calls various (inlined) functions. INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, const __global realM* restrict agm, const __global realN* restrict bgm, - __global realM* cgm, realM cpm[NWI*MWI/VWM] + __global realM* cgm, const real alpha, const real beta #if SA == 1 && SB == 1 , LOCAL_PTR realM* alm, LOCAL_PTR realN* blm #elif SA == 1 @@ -31,10 +31,12 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, ) { // Allocates workitem-private memory (registers) - //#pragma promote_to_registers + #pragma promote_to_registers realM apm[MWI/VWM]; - //#pragma promote_to_registers + #pragma promote_to_registers realN bpm[NWI/VWN]; + #pragma promote_to_registers + realM cpm[NWI*(MWI/VWM)]; // Combined thread identifier (volatile to disable caching) #if SA == 1 || SB == 1 @@ -42,7 +44,14 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #endif // Initializes the accumulation registers - InitAccRegisters(cpm); + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + #pragma unroll + for (int _ni = 0; _ni < NWI; _ni += 1) { + cpm[_ni * (MWI/VWM) + _mi] = InitAccRegisters(); + } + } + // Loops over all workgroup tiles for (int kwg = 0; kwg < kSizeK; kwg += KWG) { @@ -70,24 +79,74 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, int kg = pwi + _pit; #endif - // Loads data: local --> private (matrix A) - #if SA == 1 - LocalToPrivateA(alm, apm, kg); - // Loads data: off-chip --> private (matrix A) - #else - GlobalToPrivateA(agm, apm, kSizeM, idk, kwg); - #endif + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + // Loads data: local --> private 
(matrix A) + #if SA == 1 + apm[_mi] = LocalToPrivateA(alm, _mi, kg); + // Loads data: off-chip --> private (matrix A) + #else + apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg); + #endif + } // Loads data: local --> private (matrix B) - #if SB == 1 - LocalToPrivateB(blm, bpm, kg); - // Loads data: off-chip --> private (matrix B) - #else - GlobalToPrivateB(bgm, bpm, kSizeN, idk); - #endif + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #if SB == 1 + bpm[_ni] = LocalToPrivateB(blm, _ni, kg); + // Loads data: off-chip --> private (matrix B) + #else + bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); + #endif + } // Performs the accumulation (Cpm += Apm * Bpm) - MultiplyAccumulate(cpm, apm, bpm); + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + const realM aval = apm[_mi]; + #if VWN == 1 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); + #elif VWN == 2 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + #elif VWN == 4 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); + #elif VWN == 8 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] 
= MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + #elif VWN == 16 + cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); + cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); + cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); + cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 
11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); + cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); + cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); + cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); + cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); + #endif + } + } + } } #if SA == 1 || SB == 1 @@ -97,6 +156,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #if GLOBAL_MEM_FENCE == 1 barrier(CLK_GLOBAL_MEM_FENCE); #endif + + // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta + StoreResults(cgm, cpm, kSizeM, alpha, beta); } // ================================================================================================= @@ -127,21 +189,16 @@ void XgemmUpper(const int kSizeN, const int kSizeK, __local realN blm[KWG * NWG/VWN]; #endif - // Computes the matrix-multiplication and stores the result in register memory - //#pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; + // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm); #elif SA == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm); #elif SB == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm); #else - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta); #endif - - // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta - StoreResults(cgm, 
cpm, kSizeN, alpha, beta); } // Main entry point of the kernel. This is the lower-triangular version. @@ -168,21 +225,16 @@ void XgemmLower(const int kSizeN, const int kSizeK, __local realN blm[KWG * NWG/VWN]; #endif - // Computes the matrix-multiplication and stores the result in register memory - //#pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; + // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm); #elif SA == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm); #elif SB == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm); #else - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta); #endif - - // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta - StoreResults(cgm, cpm, kSizeN, alpha, beta); } // ================================================================================================= @@ -213,21 +265,16 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, __local realN blm[KWG * NWG/VWN]; #endif - // Computes the matrix-multiplication and stores the result in register memory - //#pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; + // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm); #elif SA == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm); #elif SB == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, 
cpm, blm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm); #else - XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta); #endif - - // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta - StoreResults(cgm, cpm, kSizeM, alpha, beta); } #endif |