Diffstat (limited to 'src')
-rw-r--r--  src/kernel_preprocessor.cpp            |  4
-rw-r--r--  src/kernels/level3/xgemm_part1.opencl  | 64
-rw-r--r--  src/kernels/level3/xgemm_part2.opencl  | 68
-rw-r--r--  src/kernels/level3/xgemm_part3.opencl  | 13
4 files changed, 77 insertions(+), 72 deletions(-)
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp
index 6239361b..493c009c 100644
--- a/src/kernel_preprocessor.cpp
+++ b/src/kernel_preprocessor.cpp
@@ -50,8 +50,8 @@ bool HasOnlyDigits(const std::string& str) {
int ParseMath(const std::string& str) {
// Handles brackets
- const auto split_close = split(str, ')');
- if (split_close.size() >= 2) {
+ if (str.find(")") != std::string::npos) {
+ const auto split_close = split(str, ')');
const auto split_end = split(split_close[0], '(');
if (split_end.size() < 2) { RaiseError(str, "Mismatching brackets #0"); }
const auto bracket_contents = ParseMath(split_end[split_end.size() - 1]);
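Why the guard changed: if the project's split() helper drops a trailing empty segment, then split("(3+4)", ')') returns a single element, so the old test split_close.size() >= 2 skips the bracket branch even though the string contains brackets. Checking find(")") first is robust either way. A minimal sketch of that failure mode (the split() below is a stand-in, not the CLBlast helper):

#include <cassert>
#include <string>
#include <vector>

// Stand-in split(): drops a trailing empty segment, like many ad-hoc helpers.
static std::vector<std::string> split(const std::string& s, const char delim) {
  std::vector<std::string> parts;
  std::string current;
  for (const auto c : s) {
    if (c == delim) { parts.push_back(current); current.clear(); }
    else { current.push_back(c); }
  }
  if (!current.empty()) { parts.push_back(current); }
  return parts;
}

int main() {
  assert(split("(3+4)", ')').size() == 1);  // old guard (size() >= 2) misfires here
  assert(std::string("(3+4)").find(")") != std::string::npos);  // new guard holds
}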
diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl
index e118ba2f..88744668 100644
--- a/src/kernels/level3/xgemm_part1.opencl
+++ b/src/kernels/level3/xgemm_part1.opencl
@@ -135,47 +135,47 @@ R"(
// =================================================================================================
// Initializes the accumulation registers to zero
-INLINE_FUNC void InitAccRegisters(realM cpm[NWI][MWI/VWM]) {
+INLINE_FUNC void InitAccRegisters(realM cpm[NWI*MWI/VWM]) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
#if VWM == 1
- SetToZero(cpm[_ni][_mi]);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi]);
#elif VWM == 2
- SetToZero(cpm[_ni][_mi].x);
- SetToZero(cpm[_ni][_mi].y);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].x);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].y);
#elif VWM == 4
- SetToZero(cpm[_ni][_mi].x);
- SetToZero(cpm[_ni][_mi].y);
- SetToZero(cpm[_ni][_mi].z);
- SetToZero(cpm[_ni][_mi].w);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].x);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].y);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].z);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].w);
#elif VWM == 8
- SetToZero(cpm[_ni][_mi].s0);
- SetToZero(cpm[_ni][_mi].s1);
- SetToZero(cpm[_ni][_mi].s2);
- SetToZero(cpm[_ni][_mi].s3);
- SetToZero(cpm[_ni][_mi].s4);
- SetToZero(cpm[_ni][_mi].s5);
- SetToZero(cpm[_ni][_mi].s6);
- SetToZero(cpm[_ni][_mi].s7);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s0);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s1);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s2);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s3);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s4);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s5);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s6);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s7);
#elif VWM == 16
- SetToZero(cpm[_ni][_mi].s0);
- SetToZero(cpm[_ni][_mi].s1);
- SetToZero(cpm[_ni][_mi].s2);
- SetToZero(cpm[_ni][_mi].s3);
- SetToZero(cpm[_ni][_mi].s4);
- SetToZero(cpm[_ni][_mi].s5);
- SetToZero(cpm[_ni][_mi].s6);
- SetToZero(cpm[_ni][_mi].s7);
- SetToZero(cpm[_ni][_mi].s8);
- SetToZero(cpm[_ni][_mi].s9);
- SetToZero(cpm[_ni][_mi].sA);
- SetToZero(cpm[_ni][_mi].sB);
- SetToZero(cpm[_ni][_mi].sC);
- SetToZero(cpm[_ni][_mi].sD);
- SetToZero(cpm[_ni][_mi].sE);
- SetToZero(cpm[_ni][_mi].sF);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s0);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s1);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s2);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s3);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s4);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s5);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s6);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s7);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s8);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].s9);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].sA);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].sB);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].sC);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].sD);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].sE);
+ SetToZero(cpm[_ni * (MWI/VWM) + _mi].sF);
#endif
}
}
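The pattern throughout these kernel changes: the two-dimensional accumulator cpm[NWI][MWI/VWM] becomes a flat array addressed row-major as _ni * (MWI/VWM) + _mi. A minimal sketch of the index equivalence (the NWI, MWI, VWM values are arbitrary examples, not tuned parameters):

#include <cassert>

constexpr int NWI = 4, MWI = 8, VWM = 2;  // example values only

int main() {
  float cpm2d[NWI][MWI / VWM] = {};
  float cpm1d[NWI * MWI / VWM] = {};
  for (int ni = 0; ni < NWI; ni += 1) {
    for (int mi = 0; mi < MWI / VWM; mi += 1) {
      cpm2d[ni][mi] = static_cast<float>(ni * 100 + mi);
      cpm1d[ni * (MWI / VWM) + mi] = cpm2d[ni][mi];  // same element, flat index
    }
  }
  assert(cpm1d[1 * (MWI / VWM) + 3] == cpm2d[1][3]);  // row-major layouts agree
}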
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
index a5507458..88100e96 100644
--- a/src/kernels/level3/xgemm_part2.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -64,48 +64,48 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva
}
// Performs the actual computation: Cpm += Apm * Bpm
-INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) {
+INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI*MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) {
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
const realM aval = apm[_mi];
#if VWN == 1
- cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni]);
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
#elif VWN == 2
- cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].x);
- cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].y);
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
#elif VWN == 4
- cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].x);
- cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].y);
- cpm[_ni*VWN + 2][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2][_mi], aval, bpm[_ni].z);
- cpm[_ni*VWN + 3][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3][_mi], aval, bpm[_ni].w);
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
#elif VWN == 8
- cpm[_ni*VWN + 0][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0][_mi], aval, bpm[_ni].s0);
- cpm[_ni*VWN + 1][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1][_mi], aval, bpm[_ni].s1);
- cpm[_ni*VWN + 2][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2][_mi], aval, bpm[_ni].s2);
- cpm[_ni*VWN + 3][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3][_mi], aval, bpm[_ni].s3);
- cpm[_ni*VWN + 4][_mi] = MultiplyAddVector(cpm[_ni*VWN + 4][_mi], aval, bpm[_ni].s4);
- cpm[_ni*VWN + 5][_mi] = MultiplyAddVector(cpm[_ni*VWN + 5][_mi], aval, bpm[_ni].s5);
- cpm[_ni*VWN + 6][_mi] = MultiplyAddVector(cpm[_ni*VWN + 6][_mi], aval, bpm[_ni].s6);
- cpm[_ni*VWN + 7][_mi] = MultiplyAddVector(cpm[_ni*VWN + 7][_mi], aval, bpm[_ni].s7);
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
#elif VWN == 16
- cpm[_ni*VWN + 0 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 0 ][_mi], aval, bpm[_ni].s0);
- cpm[_ni*VWN + 1 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 1 ][_mi], aval, bpm[_ni].s1);
- cpm[_ni*VWN + 2 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 2 ][_mi], aval, bpm[_ni].s2);
- cpm[_ni*VWN + 3 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 3 ][_mi], aval, bpm[_ni].s3);
- cpm[_ni*VWN + 4 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 4 ][_mi], aval, bpm[_ni].s4);
- cpm[_ni*VWN + 5 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 5 ][_mi], aval, bpm[_ni].s5);
- cpm[_ni*VWN + 6 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 6 ][_mi], aval, bpm[_ni].s6);
- cpm[_ni*VWN + 7 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 7 ][_mi], aval, bpm[_ni].s7);
- cpm[_ni*VWN + 8 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 8 ][_mi], aval, bpm[_ni].s8);
- cpm[_ni*VWN + 9 ][_mi] = MultiplyAddVector(cpm[_ni*VWN + 9 ][_mi], aval, bpm[_ni].s9);
- cpm[_ni*VWN + 10][_mi] = MultiplyAddVector(cpm[_ni*VWN + 10][_mi], aval, bpm[_ni].sA);
- cpm[_ni*VWN + 11][_mi] = MultiplyAddVector(cpm[_ni*VWN + 11][_mi], aval, bpm[_ni].sB);
- cpm[_ni*VWN + 12][_mi] = MultiplyAddVector(cpm[_ni*VWN + 12][_mi], aval, bpm[_ni].sC);
- cpm[_ni*VWN + 13][_mi] = MultiplyAddVector(cpm[_ni*VWN + 13][_mi], aval, bpm[_ni].sD);
- cpm[_ni*VWN + 14][_mi] = MultiplyAddVector(cpm[_ni*VWN + 14][_mi], aval, bpm[_ni].sE);
- cpm[_ni*VWN + 15][_mi] = MultiplyAddVector(cpm[_ni*VWN + 15][_mi], aval, bpm[_ni].sF);
+ cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
+ cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
+ cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
+ cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
+ cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
+ cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
+ cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
+ cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
+ cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
#endif
}
}
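A scalar reference may help read the vector-width cases above: component k of bpm[_ni] updates accumulator row _ni*VWN + k. A sketch with VWM == 1 so that realM degenerates to a scalar (example values, not tuned parameters):

constexpr int NWI = 4, MWI = 4, VWM = 1, VWN = 2;  // example values only

// Reference semantics of MultiplyAccumulate for scalar realM/realN: row
// (ni*VWN + k) of the flat accumulator gains apm[mi] times the k-th
// component of bpm[ni], i.e. bpm_scalar[ni*VWN + k].
void MultiplyAccumulateRef(float cpm[NWI * (MWI / VWM)],
                           const float apm[MWI / VWM],
                           const float bpm[NWI]) {
  for (int ni = 0; ni < NWI / VWN; ni += 1) {
    for (int mi = 0; mi < MWI / VWM; mi += 1) {
      for (int k = 0; k < VWN; k += 1) {
        cpm[(ni * VWN + k) * (MWI / VWM) + mi] += apm[mi] * bpm[ni * VWN + k];
      }
    }
  }
}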
@@ -115,7 +115,7 @@ INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM],
// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
-INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM,
+INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI*MWI/VWM], const int kSizeM,
const real alpha, const real beta) {
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
@@ -136,7 +136,7 @@ INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], cons
int index = idn*(kSizeM/VWM) + idm;
realM result;
- realM xval = cpm[_ni][_mi];
+ realM xval = cpm[_ni * (MWI/VWM) + _mi];
// The final multiplication with alpha (in case beta == 0)
if (IsZero(beta)) {
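A scalar sketch of the per-element merge StoreResults performs, following the comment above (Cgm = alpha*Cpm + beta*Cgm); the beta == 0 branch skips reading cgm, matching the "final multiplication with alpha" path in the hunk:

// A sketch, not the kernel: per-element alpha/beta merge.
void StoreResultsRef(float* cgm, const float* cpm, const int n,
                     const float alpha, const float beta) {
  for (int i = 0; i < n; i += 1) {
    if (beta == 0.0f) {
      cgm[i] = alpha * cpm[i];              // no read of cgm when beta == 0
    } else {
      cgm[i] = alpha * cpm[i] + beta * cgm[i];
    }
  }
}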
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl
index 4e85c4a8..7e46cef5 100644
--- a/src/kernels/level3/xgemm_part3.opencl
+++ b/src/kernels/level3/xgemm_part3.opencl
@@ -20,7 +20,7 @@ R"(
// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.
INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
const __global realM* restrict agm, const __global realN* restrict bgm,
- __global realM* cgm, realM cpm[NWI][MWI/VWM]
+ __global realM* cgm, realM cpm[NWI*MWI/VWM]
#if SA == 1 && SB == 1
, LOCAL_PTR realM* alm, LOCAL_PTR realN* blm
#elif SA == 1
@@ -31,7 +31,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
) {
// Allocates workitem-private memory (registers)
+ #pragma promote_to_registers
realM apm[MWI/VWM];
+ #pragma promote_to_registers
realN bpm[NWI/VWN];
// Combined thread identifier (volatile to disable caching)
@@ -126,7 +128,8 @@ void XgemmUpper(const int kSizeN, const int kSizeK,
#endif
// Computes the matrix-multiplication and stores the result in register memory
- realM cpm[NWI][MWI/VWM];
+ #pragma promote_to_registers
+ realM cpm[NWI*(MWI/VWM)];
#if SA == 1 && SB == 1
XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
#elif SA == 1
@@ -166,7 +169,8 @@ void XgemmLower(const int kSizeN, const int kSizeK,
#endif
// Computes the matrix-multiplication and stores the result in register memory
- realM cpm[NWI][MWI/VWM];
+ #pragma promote_to_registers
+ realM cpm[NWI*(MWI/VWM)];
#if SA == 1 && SB == 1
XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
#elif SA == 1
@@ -210,7 +214,8 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// Computes the matrix-multiplication and stores the result in register memory
- realM cpm[NWI][MWI/VWM];
+ #pragma promote_to_registers
+ realM cpm[NWI*(MWI/VWM)];
#if SA == 1 && SB == 1
XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
#elif SA == 1
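The promote_to_registers hints tie the change together: flattening cpm from two dimensions to one plausibly makes it amenable to the preprocessor's array-to-register promotion (an assumption about intent; the pass itself lives in src/kernel_preprocessor.cpp, and before()/after() below are hypothetical names). A sketch of the kind of rewrite such a promotion performs:

// Sketch (assumption about the transformation, not the preprocessor's code):
// once every index into a small flat array is a compile-time constant, as it
// is after full loop unrolling, the array can be replaced by one scalar
// variable per element.
float before() {
  float cpm[2];
  cpm[0] = 1.0f; cpm[1] = 2.0f;   // literal indices only, as after unrolling
  return cpm[0] + cpm[1];
}
float after() {
  float cpm_0 = 1.0f; float cpm_1 = 2.0f;  // promoted: no array, no indexing
  return cpm_0 + cpm_1;
}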