diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-09-15 21:47:04 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-15 21:47:04 +0200 |
commit | c163868e1822a97750b4380f0d9cdd38369f9f0b (patch) | |
tree | fbdbc5b5697be8cf4237cc11f87d0b1649f4190d /src | |
parent | 91dbd580ab2f5d2363d51ba4e3fc9735f1c7a937 (diff) | |
parent | 0f6dd01e513db036191786aed3d03a77e2e8c5dc (diff) |
Merge pull request #318 from CNugteren/CLBlast-315-preprocessor-gemmk1-issue
Fixed pre-processor issues with the new GEMMK=1 kernel
Diffstat (limited to 'src')
-rw-r--r-- | src/kernel_preprocessor.cpp | 2 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part1.opencl | 6 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part2.opencl | 6 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part3.opencl | 73 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part4.opencl | 4 | ||||
-rw-r--r-- | src/routines/level3/xgemm.cpp | 1 | ||||
-rw-r--r-- | src/routines/level3/xherk.cpp | 1 | ||||
-rw-r--r-- | src/routines/level3/xsyrk.cpp | 1 | ||||
-rw-r--r-- | src/routines/levelx/xgemmbatched.cpp | 1 | ||||
-rw-r--r-- | src/routines/levelx/xgemmstridedbatched.cpp | 1 | ||||
-rw-r--r-- | src/tuning/kernels/xgemm.hpp | 2 | ||||
-rw-r--r-- | src/utilities/compile.cpp | 3 |
12 files changed, 48 insertions, 53 deletions
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp index aa946bab..1c422d33 100644 --- a/src/kernel_preprocessor.cpp +++ b/src/kernel_preprocessor.cpp @@ -557,6 +557,8 @@ std::string PreprocessKernelSource(const std::string& kernel_source) { lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false); + lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false); + lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, true); // Gather the results diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl index 3cfc5dfb..80a60107 100644 --- a/src/kernels/level3/xgemm_part1.opencl +++ b/src/kernels/level3/xgemm_part1.opencl @@ -43,8 +43,6 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( -// ================================================================================================= - // Parameters set by the tuner or by the database. Here they are given a basic default value in case // this kernel file is used outside of the CLBlast library. #ifndef GEMMK @@ -397,9 +395,7 @@ INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm, const int _ni, const int } #endif -// ================================================================================================= - -// End of the C++11 raw string literal )" +// End of the C++11 raw string literal // ================================================================================================= diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl index 17c8955a..ee4d5da5 100644 --- a/src/kernels/level3/xgemm_part2.opencl +++ b/src/kernels/level3/xgemm_part2.opencl @@ -15,8 +15,6 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( -// ================================================================================================= - // The vectorised multiply-add function INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bval) { #if USE_VECTOR_MAD == 1 @@ -171,9 +169,7 @@ INLINE_FUNC void StoreResults(__global realM* cgm, realM c_value, const int _mi, cgm[index] = result; } -// ================================================================================================= - -// End of the C++11 raw string literal )" +// End of the C++11 raw string literal // ================================================================================================= diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl index 90de0b3b..77964a94 100644 --- a/src/kernels/level3/xgemm_part3.opencl +++ b/src/kernels/level3/xgemm_part3.opencl @@ -15,14 +15,12 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( -// ================================================================================================= - // A common interface for subgroup functions #if USE_SUBGROUP_SHUFFLING == 1 INLINE_FUNC int clblast_get_sub_group_local_id() { - + // Intel extension #if SUBGROUP_SHUFFLING_INTEL == 1 return get_sub_group_local_id(); @@ -36,7 +34,7 @@ INLINE_FUNC int clblast_get_sub_group_local_id() { } INLINE_FUNC realN clblast_sub_group_shuffle(realN reg, int src) { - + // Intel extension #if SUBGROUP_SHUFFLING_INTEL == 1 return intel_sub_group_shuffle(reg, src); @@ -238,48 +236,47 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { #pragma unroll for (int _ki = 0; _ki < KREG/VWN; _ki += 1) { - const int index = _ni * (MWI/VWM) + _mi; #if USE_SUBGROUP_SHUFFLING == 1 const realN aval = clblast_sub_group_shuffle(apm[_ki], _ni); #else const realN aval = apm[_ni * (KREG/VWN) + _ki]; #endif #if VWN == 1 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval); #elif VWN == 2 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); #elif VWN == 4 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w); #elif VWN == 8 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7); #elif VWN == 16 - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE); - cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE); + cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF); #endif } } @@ -311,9 +308,7 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, } } -// ================================================================================================= - -// End of the C++11 raw string literal )" +// End of the C++11 raw string literal // ================================================================================================= diff --git a/src/kernels/level3/xgemm_part4.opencl b/src/kernels/level3/xgemm_part4.opencl index e581cd84..b1f1ade6 100644 --- a/src/kernels/level3/xgemm_part4.opencl +++ b/src/kernels/level3/xgemm_part4.opencl @@ -15,7 +15,6 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( -// ================================================================================================= // The upper-triangular and lower-triangular kernels are only used in special cases #if defined(ROUTINE_SYRK) || defined(ROUTINE_HERK) || defined(ROUTINE_SYR2K) || defined(ROUTINE_HER2K) @@ -132,9 +131,8 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, } #endif -// ================================================================================================= -// End of the C++11 raw string literal )" +// End of the C++11 raw string literal // ================================================================================================= diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index fd5a20db..cb24460a 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -40,6 +40,7 @@ Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" }) { diff --git a/src/routines/level3/xherk.cpp b/src/routines/level3/xherk.cpp index 6912d3a9..2e6f30ec 100644 --- a/src/routines/level3/xherk.cpp +++ b/src/routines/level3/xherk.cpp @@ -32,6 +32,7 @@ Xherk<T,U>::Xherk(Queue &queue, EventPointer event, const std::string &name): , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" }) { diff --git a/src/routines/level3/xsyrk.cpp b/src/routines/level3/xsyrk.cpp index 6bb2a24f..5ffdc028 100644 --- a/src/routines/level3/xsyrk.cpp +++ b/src/routines/level3/xsyrk.cpp @@ -32,6 +32,7 @@ Xsyrk<T>::Xsyrk(Queue &queue, EventPointer event, const std::string &name): , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" }) { diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp index 2bbc5007..b12b8734 100644 --- a/src/routines/levelx/xgemmbatched.cpp +++ b/src/routines/levelx/xgemmbatched.cpp @@ -38,6 +38,7 @@ XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::strin , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" , // separated in multiple parts to prevent C1091 in MSVC 2013 diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp index 30c161cc..d9e3ebba 100644 --- a/src/routines/levelx/xgemmstridedbatched.cpp +++ b/src/routines/levelx/xgemmstridedbatched.cpp @@ -37,6 +37,7 @@ XgemmStridedBatched<T>::XgemmStridedBatched(Queue &queue, EventPointer event, co , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" , // separated in multiple parts to prevent C1091 in MSVC 2013 diff --git a/src/tuning/kernels/xgemm.hpp b/src/tuning/kernels/xgemm.hpp index 9a538c1b..fa1bb6ec 100644 --- a/src/tuning/kernels/xgemm.hpp +++ b/src/tuning/kernels/xgemm.hpp @@ -50,6 +50,8 @@ TunerSettings XgemmGetTunerSettings(const int V, const Arguments<T> &args) { settings.sources += #include "../src/kernels/level3/xgemm_part1.opencl" #include "../src/kernels/level3/xgemm_part2.opencl" + ; + settings.sources += #include "../src/kernels/level3/xgemm_part3.opencl" #include "../src/kernels/level3/xgemm_part4.opencl" ; diff --git a/src/utilities/compile.cpp b/src/utilities/compile.cpp index 835f54b4..00cb90cb 100644 --- a/src/utilities/compile.cpp +++ b/src/utilities/compile.cpp @@ -59,7 +59,8 @@ std::shared_ptr<Program> CompileFromSource( } // For Intel GPUs with subgroup support, use subgroup shuffling. - if (device.IsGPU() && device.HasExtension(kKhronosIntelSubgroups)) { + if (device.IsGPU() && device.HasExtension(kKhronosIntelSubgroups) && + (precision == Precision::kSingle || precision == Precision::kHalf)) { header_string += "#define USE_SUBGROUP_SHUFFLING 1\n"; header_string += "#define SUBGROUP_SHUFFLING_INTEL 1\n"; } |