summaryrefslogtreecommitdiff
path: root/src/kernels
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-09-15 21:47:04 +0200
committerGitHub <noreply@github.com>2018-09-15 21:47:04 +0200
commitc163868e1822a97750b4380f0d9cdd38369f9f0b (patch)
treefbdbc5b5697be8cf4237cc11f87d0b1649f4190d /src/kernels
parent91dbd580ab2f5d2363d51ba4e3fc9735f1c7a937 (diff)
parent0f6dd01e513db036191786aed3d03a77e2e8c5dc (diff)
Merge pull request #318 from CNugteren/CLBlast-315-preprocessor-gemmk1-issue
Fixed pre-processor issues with the new GEMMK=1 kernel
Diffstat (limited to 'src/kernels')
-rw-r--r--src/kernels/level3/xgemm_part1.opencl6
-rw-r--r--src/kernels/level3/xgemm_part2.opencl6
-rw-r--r--src/kernels/level3/xgemm_part3.opencl73
-rw-r--r--src/kernels/level3/xgemm_part4.opencl4
4 files changed, 37 insertions, 52 deletions
diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl
index 3cfc5dfb..80a60107 100644
--- a/src/kernels/level3/xgemm_part1.opencl
+++ b/src/kernels/level3/xgemm_part1.opencl
@@ -43,8 +43,6 @@
// literal). Comment-out this line for syntax-highlighting when developing.
R"(
-// =================================================================================================
-
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
// this kernel file is used outside of the CLBlast library.
#ifndef GEMMK
@@ -397,9 +395,7 @@ INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm, const int _ni, const int
}
#endif
-// =================================================================================================
-
-// End of the C++11 raw string literal
)"
+// End of the C++11 raw string literal
// =================================================================================================
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
index 17c8955a..ee4d5da5 100644
--- a/src/kernels/level3/xgemm_part2.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -15,8 +15,6 @@
// literal). Comment-out this line for syntax-highlighting when developing.
R"(
-// =================================================================================================
-
// The vectorised multiply-add function
INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bval) {
#if USE_VECTOR_MAD == 1
@@ -171,9 +169,7 @@ INLINE_FUNC void StoreResults(__global realM* cgm, realM c_value, const int _mi,
cgm[index] = result;
}
-// =================================================================================================
-
-// End of the C++11 raw string literal
)"
+// End of the C++11 raw string literal
// =================================================================================================
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl
index 90de0b3b..77964a94 100644
--- a/src/kernels/level3/xgemm_part3.opencl
+++ b/src/kernels/level3/xgemm_part3.opencl
@@ -15,14 +15,12 @@
// literal). Comment-out this line for syntax-highlighting when developing.
R"(
-// =================================================================================================
-
// A common interface for subgroup functions
#if USE_SUBGROUP_SHUFFLING == 1
INLINE_FUNC int clblast_get_sub_group_local_id() {
-
+
// Intel extension
#if SUBGROUP_SHUFFLING_INTEL == 1
return get_sub_group_local_id();
@@ -36,7 +34,7 @@ INLINE_FUNC int clblast_get_sub_group_local_id() {
}
INLINE_FUNC realN clblast_sub_group_shuffle(realN reg, int src) {
-
+
// Intel extension
#if SUBGROUP_SHUFFLING_INTEL == 1
return intel_sub_group_shuffle(reg, src);
@@ -238,48 +236,47 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
#pragma unroll
for (int _ki = 0; _ki < KREG/VWN; _ki += 1) {
- const int index = _ni * (MWI/VWM) + _mi;
#if USE_SUBGROUP_SHUFFLING == 1
const realN aval = clblast_sub_group_shuffle(apm[_ki], _ni);
#else
const realN aval = apm[_ni * (KREG/VWN) + _ki];
#endif
#if VWN == 1
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval);
#elif VWN == 2
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y);
#elif VWN == 4
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w);
#elif VWN == 8
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7);
#elif VWN == 16
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE);
- cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE);
+ cpm[_ni * (MWI/VWM) + _mi] = MultiplyAddVector(cpm[_ni * (MWI/VWM) + _mi], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF);
#endif
}
}
@@ -311,9 +308,7 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
}
}
-// =================================================================================================
-
-// End of the C++11 raw string literal
)"
+// End of the C++11 raw string literal
// =================================================================================================
diff --git a/src/kernels/level3/xgemm_part4.opencl b/src/kernels/level3/xgemm_part4.opencl
index e581cd84..b1f1ade6 100644
--- a/src/kernels/level3/xgemm_part4.opencl
+++ b/src/kernels/level3/xgemm_part4.opencl
@@ -15,7 +15,6 @@
// literal). Comment-out this line for syntax-highlighting when developing.
R"(
-// =================================================================================================
// The upper-triangular and lower-triangular kernels are only used in special cases
#if defined(ROUTINE_SYRK) || defined(ROUTINE_HERK) || defined(ROUTINE_SYR2K) || defined(ROUTINE_HER2K)
@@ -132,9 +131,8 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
}
#endif
-// =================================================================================================
-// End of the C++11 raw string literal
)"
+// End of the C++11 raw string literal
// =================================================================================================