From 540896476d62ce37e7a939d185c15dc930b8a343 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Thu, 7 Dec 2017 22:05:29 +0100 Subject: Added register promotion to the main GEMM kernel --- src/kernel_preprocessor.cpp | 4 +-- src/kernels/level3/xgemm_part1.opencl | 64 ++++++++++++++++----------------- src/kernels/level3/xgemm_part2.opencl | 68 +++++++++++++++++------------------ src/kernels/level3/xgemm_part3.opencl | 13 ++++--- 4 files changed, 77 insertions(+), 72 deletions(-) (limited to 'src') diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp index 6239361b..493c009c 100644 --- a/src/kernel_preprocessor.cpp +++ b/src/kernel_preprocessor.cpp @@ -50,8 +50,8 @@ bool HasOnlyDigits(const std::string& str) { int ParseMath(const std::string& str) { // Handles brackets - const auto split_close = split(str, ')'); - if (split_close.size() >= 2) { + if (str.find(")") != std::string::npos) { + const auto split_close = split(str, ')'); const auto split_end = split(split_close[0], '('); if (split_end.size() < 2) { RaiseError(str, "Mismatching brackets #0"); } const auto bracket_contents = ParseMath(split_end[split_end.size() - 1]); diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl index e118ba2f..88744668 100644 --- a/src/kernels/level3/xgemm_part1.opencl +++ b/src/kernels/level3/xgemm_part1.opencl @@ -135,47 +135,47 @@ R"( // ================================================================================================= // Initializes the accumulation registers to zero -INLINE_FUNC void InitAccRegisters(realM cpm[NWI][MWI/VWM]) { +INLINE_FUNC void InitAccRegisters(realM cpm[NWI*MWI/VWM]) { #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { #pragma unroll for (int _ni = 0; _ni < NWI; _ni += 1) { #if VWM == 1 - SetToZero(cpm[_ni][_mi]); + SetToZero(cpm[_ni * (MWI/VWM) + _mi]); #elif VWM == 2 - SetToZero(cpm[_ni][_mi].x); - SetToZero(cpm[_ni][_mi].y); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].x); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].y); #elif VWM == 4 - SetToZero(cpm[_ni][_mi].x); - SetToZero(cpm[_ni][_mi].y); - SetToZero(cpm[_ni][_mi].z); - SetToZero(cpm[_ni][_mi].w); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].x); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].y); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].z); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].w); #elif VWM == 8 - SetToZero(cpm[_ni][_mi].s0); - SetToZero(cpm[_ni][_mi].s1); - SetToZero(cpm[_ni][_mi].s2); - SetToZero(cpm[_ni][_mi].s3); - SetToZero(cpm[_ni][_mi].s4); - SetToZero(cpm[_ni][_mi].s5); - SetToZero(cpm[_ni][_mi].s6); - SetToZero(cpm[_ni][_mi].s7); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s0); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s1); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s2); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s3); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s4); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s5); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s6); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s7); #elif VWM == 16 - SetToZero(cpm[_ni][_mi].s0); - SetToZero(cpm[_ni][_mi].s1); - SetToZero(cpm[_ni][_mi].s2); - SetToZero(cpm[_ni][_mi].s3); - SetToZero(cpm[_ni][_mi].s4); - SetToZero(cpm[_ni][_mi].s5); - SetToZero(cpm[_ni][_mi].s6); - SetToZero(cpm[_ni][_mi].s7); - SetToZero(cpm[_ni][_mi].s8); - SetToZero(cpm[_ni][_mi].s9); - SetToZero(cpm[_ni][_mi].sA); - SetToZero(cpm[_ni][_mi].sB); - SetToZero(cpm[_ni][_mi].sC); - SetToZero(cpm[_ni][_mi].sD); - SetToZero(cpm[_ni][_mi].sE); - SetToZero(cpm[_ni][_mi].sF); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s0); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s1); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s2); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s3); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s4); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s5); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s6); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s7); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s8); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].s9); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].sA); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].sB); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].sC); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].sD); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].sE); + SetToZero(cpm[_ni * (MWI/VWM) + _mi].sF); #endif } } diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl index a5507458..88100e96 100644 --- a/src/kernels/level3/xgemm_part2.opencl +++ b/src/kernels/level3/xgemm_part2.opencl @@ -64,48 +64,48 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva } // Performs the actual computation: Cpm += Apm * Bpm -INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) { +INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI*MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) { #pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { const realM aval = apm[_mi]; #if VWN == 1 - cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni]); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); #elif VWN == 2 - cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].x); - cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); #elif VWN == 4 - cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].x); - cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].y); - cpm[_ni*VWN + 2][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2][_mi], aval, bpm[_ni].z); - cpm[_ni*VWN + 3][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3][_mi], aval, bpm[_ni].w); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); #elif VWN == 8 - cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].s0); - cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].s1); - cpm[_ni*VWN + 2][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2][_mi], aval, bpm[_ni].s2); - cpm[_ni*VWN + 3][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3][_mi], aval, bpm[_ni].s3); - cpm[_ni*VWN + 4][_mi] = MultiplyAddVector(cpm[_ni*VWN + 4][_mi], aval, bpm[_ni].s4); - cpm[_ni*VWN + 5][_mi] = MultiplyAddVector(cpm[_ni*VWN + 5][_mi], aval, bpm[_ni].s5); - cpm[_ni*VWN + 6][_mi] = MultiplyAddVector(cpm[_ni*VWN + 6][_mi], aval, bpm[_ni].s6); - cpm[_ni*VWN + 7][_mi] = MultiplyAddVector(cpm[_ni*VWN + 7][_mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); #elif VWN == 16 - cpm[_ni*VWN + 0 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0 ][_mi], aval, bpm[_ni].s0); - cpm[_ni*VWN + 1 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1 ][_mi], aval, bpm[_ni].s1); - cpm[_ni*VWN + 2 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2 ][_mi], aval, bpm[_ni].s2); - cpm[_ni*VWN + 3 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3 ][_mi], aval, bpm[_ni].s3); - cpm[_ni*VWN + 4 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 4 ][_mi], aval, bpm[_ni].s4); - cpm[_ni*VWN + 5 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 5 ][_mi], aval, bpm[_ni].s5); - cpm[_ni*VWN + 6 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 6 ][_mi], aval, bpm[_ni].s6); - cpm[_ni*VWN + 7 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 7 ][_mi], aval, bpm[_ni].s7); - cpm[_ni*VWN + 8 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 8 ][_mi], aval, bpm[_ni].s8); - cpm[_ni*VWN + 9 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 9 ][_mi], aval, bpm[_ni].s9); - cpm[_ni*VWN + 10][_mi] = MultiplyAddVector(cpm[_ni*VWN + 10][_mi], aval, bpm[_ni].sA); - cpm[_ni*VWN + 11][_mi] = MultiplyAddVector(cpm[_ni*VWN + 11][_mi], aval, bpm[_ni].sB); - cpm[_ni*VWN + 12][_mi] = MultiplyAddVector(cpm[_ni*VWN + 12][_mi], aval, bpm[_ni].sC); - cpm[_ni*VWN + 13][_mi] = MultiplyAddVector(cpm[_ni*VWN + 13][_mi], aval, bpm[_ni].sD); - cpm[_ni*VWN + 14][_mi] = MultiplyAddVector(cpm[_ni*VWN + 14][_mi], aval, bpm[_ni].sE); - cpm[_ni*VWN + 15][_mi] = MultiplyAddVector(cpm[_ni*VWN + 15][_mi], aval, bpm[_ni].sF); + cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); + cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); + cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); + cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); + cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); + cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); + cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); + cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); #endif } } @@ -115,7 +115,7 @@ INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], // Merges the results in Cpm with the global array in Cgm. This also performs the multiplication // with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm -INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM, +INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI*MWI/VWM], const int kSizeM, const real alpha, const real beta) { #pragma unroll for (int _ni = 0; _ni < NWI; _ni += 1) { @@ -136,7 +136,7 @@ INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], cons int index = idn*(kSizeM/VWM) + idm; realM result; - realM xval = cpm[_ni][_mi]; + realM xval = cpm[_ni * (MWI/VWM) + _mi]; // The final multiplication with alpha (in case beta == 0) if (IsZero(beta)) { diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl index 4e85c4a8..7e46cef5 100644 --- a/src/kernels/level3/xgemm_part3.opencl +++ b/src/kernels/level3/xgemm_part3.opencl @@ -20,7 +20,7 @@ R"( // Main body of the matrix-multiplication algorithm. It calls various (inlined) functions. INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, const __global realM* restrict agm, const __global realN* restrict bgm, - __global realM* cgm, realM cpm[NWI][MWI/VWM] + __global realM* cgm, realM cpm[NWI*MWI/VWM] #if SA == 1 && SB == 1 , LOCAL_PTR realM* alm, LOCAL_PTR realN* blm #elif SA == 1 @@ -31,7 +31,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, ) { // Allocates workitem-private memory (registers) + #pragma promote_to_registers realM apm[MWI/VWM]; + #pragma promote_to_registers realN bpm[NWI/VWN]; // Combined thread identifier (volatile to disable caching) @@ -126,7 +128,8 @@ void XgemmUpper(const int kSizeN, const int kSizeK, #endif // Computes the matrix-multiplication and stores the result in register memory - realM cpm[NWI][MWI/VWM]; + #pragma promote_to_registers + realM cpm[NWI*(MWI/VWM)]; #if SA == 1 && SB == 1 XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); #elif SA == 1 @@ -166,7 +169,8 @@ void XgemmLower(const int kSizeN, const int kSizeK, #endif // Computes the matrix-multiplication and stores the result in register memory - realM cpm[NWI][MWI/VWM]; + #pragma promote_to_registers + realM cpm[NWI*(MWI/VWM)]; #if SA == 1 && SB == 1 XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); #elif SA == 1 @@ -210,7 +214,8 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, #endif // Computes the matrix-multiplication and stores the result in register memory - realM cpm[NWI][MWI/VWM]; + #pragma promote_to_registers + realM cpm[NWI*(MWI/VWM)]; #if SA == 1 && SB == 1 XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); #elif SA == 1 -- cgit v1.2.3