diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-09-15 21:47:04 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-15 21:47:04 +0200 |
commit | c163868e1822a97750b4380f0d9cdd38369f9f0b (patch) | |
tree | fbdbc5b5697be8cf4237cc11f87d0b1649f4190d /src/kernels | |
parent | 91dbd580ab2f5d2363d51ba4e3fc9735f1c7a937 (diff) | |
parent | 0f6dd01e513db036191786aed3d03a77e2e8c5dc (diff) |
Merge pull request #318 from CNugteren/CLBlast-315-preprocessor-gemmk1-issue
Fixed pre-processor issues with the new GEMMK=1 kernel
Diffstat (limited to 'src/kernels')
-rw-r--r-- | src/kernels/level3/xgemm_part1.opencl | 6 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part2.opencl | 6 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part3.opencl | 73 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part4.opencl | 4 |
4 files changed, 37 insertions, 52 deletions
diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl index 3cfc5dfb..80a60107 100644 --- a/src/kernels/level3/xgemm_part1.opencl +++ b/src/kernels/level3/xgemm_part1.opencl @@ -43,8 +43,6 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( -// ================================================================================================= - // Parameters set by the tuner or by the database. Here they are given a basic default value in case // this kernel file is used outside of the CLBlast library. #ifndef GEMMK @@ -397,9 +395,7 @@ INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm, const int _ni, const int } #endif -// ================================================================================================= - -// End of the C++11 raw string literal )" +// End of the C++11 raw string literal // ================================================================================================= diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl index 17c8955a..ee4d5da5 100644 --- a/src/kernels/level3/xgemm_part2.opencl +++ b/src/kernels/level3/xgemm_part2.opencl @@ -15,8 +15,6 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( -// ================================================================================================= - // The vectorised multiply-add function INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bval) { #if USE_VECTOR_MAD == 1 @@ -171,9 +169,7 @@ INLINE_FUNC void StoreResults(__global realM* cgm, realM c_value, const int _mi, cgm[index] = result; } -// ================================================================================================= - -// End of the C++11 raw string literal )" +// End of the C++11 raw string literal // ================================================================================================= diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl index 90de0b3b..77964a94 100644 --- a/src/kernels/level3/xgemm_part3.opencl +++ b/src/kernels/level3/xgemm_part3.opencl @@ -15,14 +15,12 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( -// ================================================================================================= - // A common interface for subgroup functions #if USE_SUBGROUP_SHUFFLING == 1 INLINE_FUNC int clblast_get_sub_group_local_id() { - + // Intel extension #if SUBGROUP_SHUFFLING_INTEL == 1 return get_sub_group_local_id(); @@ -36,7 +34,7 @@ INLINE_FUNC int clblast_get_sub_group_local_id() { } INLINE_FUNC realN clblast_sub_group_shuffle(realN reg, int src) { - + // Intel extension #if SUBGROUP_SHUFFLING_INTEL == 1 return intel_sub_group_shuffle(reg, src); @@ -238,48 +236,47 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { #pragma unroll for (int _ki = 0; _ki < KREG/VWN; _ki += 1) { - const int index = _ni * (MWI/VWM) + _mi; #if USE_SUBGROUP_SHUFFLING == 1 const realN aval = clblast_sub_group_shuffle(apm[_ki], _ni); #else const realN aval = apm[_ni * (KREG/VWN) + _ki]; #endif #if VWN == 1 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval); #elif VWN == 2 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); #elif VWN == 4 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w); #elif VWN == 8 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7); #elif VWN == 16 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF); #endif } } @@ -311,9 +308,7 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, } } -// ================================================================================================= - -// End of the C++11 raw string literal )" +// End of the C++11 raw string literal // ================================================================================================= diff --git a/src/kernels/level3/xgemm_part4.opencl b/src/kernels/level3/xgemm_part4.opencl index e581cd84..b1f1ade6 100644 --- a/src/kernels/level3/xgemm_part4.opencl +++ b/src/kernels/level3/xgemm_part4.opencl @@ -15,7 +15,6 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( -// ================================================================================================= // The upper-triangular and lower-triangular kernels are only used in special cases #if defined(ROUTINE_SYRK) || defined(ROUTINE_HERK) || defined(ROUTINE_SYR2K) || defined(ROUTINE_HER2K) @@ -132,9 +131,8 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, } #endif -// ================================================================================================= -// End of the C++11 raw string literal )" +// End of the C++11 raw string literal // ================================================================================================= |