summaryrefslogtreecommitdiff
path: root/src/kernels
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-03-23 20:29:20 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-03-23 20:29:20 +0100
commit1cbe2ea301c6b28a7d1101142ff347471f7dc197 (patch)
treee4c9b4f8072daebe45e6e1bc5059cf7a798eb9d9 /src/kernels
parent52791bf3553bb47a50dea4ac234f7e1b09c4383c (diff)
Removed arrays as function argument from GEMM kernels for Vivante OpenCL compiler
Diffstat (limited to 'src/kernels')
-rw-r--r--src/kernels/level3/xgemm_direct_part1.opencl75
-rw-r--r--src/kernels/level3/xgemm_direct_part3.opencl18
-rw-r--r--src/kernels/level3/xgemm_part2.opencl202
-rw-r--r--src/kernels/level3/xgemm_part3.opencl8
4 files changed, 153 insertions, 150 deletions
diff --git a/src/kernels/level3/xgemm_direct_part1.opencl b/src/kernels/level3/xgemm_direct_part1.opencl
index 38aa31fb..8ca2ceb4 100644
--- a/src/kernels/level3/xgemm_direct_part1.opencl
+++ b/src/kernels/level3/xgemm_direct_part1.opencl
@@ -171,59 +171,48 @@ INLINE_FUNC real LocalToPrivateDirectB(LOCAL_PTR real* blm, const int _ni, const
// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
-INLINE_FUNC void StoreResultsDirect(__global real* cgm, real cpd[NWID * MWID],
- const int idm, const int idn,
+INLINE_FUNC void StoreResultsDirect(__global real* cgm, const real c_value,
+ const int _mi, const int _ni, const int idm, const int idn,
const real alpha, const real beta,
const int c_ld, const int c_offset, const int c_transpose) {
- #pragma unroll
- for (int _ni = 0; _ni < NWID; _ni += 1) {
- #pragma unroll
- for (int _mi = 0; _mi < MWID; _mi += 1) {
-
- // Deter_mines the destination index
- int c_index = (c_transpose) ? (idm + _mi)*c_ld + (idn + _ni) : (idn + _ni)*c_ld + (idm + _mi);
-
- // The final multiplication with alpha (in case beta == 0)
- real result;
- if (IsZero(beta)) {
- Multiply(result, alpha, cpd[_ni * MWID + _mi]);
- }
- // The final multiplication with alpha and the addition with beta*C
- else {
- AXPBY(result, alpha, cpd[_ni * MWID + _mi], beta, cgm[c_index + c_offset]);
- }
- cgm[c_index + c_offset] = result;
- }
+
+ // Determines the destination index
+ int c_index = (c_transpose) ? (idm + _mi)*c_ld + (idn + _ni) : (idn + _ni)*c_ld + (idm + _mi);
+
+ // The final multiplication with alpha (in case beta == 0)
+ real result;
+ if (IsZero(beta)) {
+ Multiply(result, alpha, c_value);
}
+ // The final multiplication with alpha and the addition with beta*C
+ else {
+ AXPBY(result, alpha, c_value, beta, cgm[c_index + c_offset]);
+ }
+ cgm[c_index + c_offset] = result;
}
// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
-INLINE_FUNC void StoreResultsChecked(__global real* cgm, real cpd[NWID * MWID],
- const int idm, const int idn, const int kSizeM, const int kSizeN,
+INLINE_FUNC void StoreResultsChecked(__global real* cgm, const real c_value,
+ const int _mi, const int _ni, const int idm, const int idn,
+ const int kSizeM, const int kSizeN,
const real alpha, const real beta,
const int c_ld, const int c_offset, const int c_transpose) {
- #pragma unroll
- for (int _ni = 0; _ni < NWID; _ni += 1) {
- #pragma unroll
- for (int _mi = 0; _mi < MWID; _mi += 1) {
- if ((idm + _mi) < kSizeM && (idn + _ni) < kSizeN) {
-
- // Deter_mines the destination index
- int c_index = (c_transpose) ? (idm + _mi)*c_ld + (idn + _ni) : (idn + _ni)*c_ld + (idm + _mi);
-
- // The final multiplication with alpha (in case beta == 0)
- real result;
- if (IsZero(beta)) {
- Multiply(result, alpha, cpd[_ni * MWID + _mi]);
- }
- // The final multiplication with alpha and the addition with beta*C
- else {
- AXPBY(result, alpha, cpd[_ni * MWID + _mi], beta, cgm[c_index + c_offset]);
- }
- cgm[c_index + c_offset] = result;
- }
+ if ((idm + _mi) < kSizeM && (idn + _ni) < kSizeN) {
+
+ // Deter_mines the destination index
+ int c_index = (c_transpose) ? (idm + _mi)*c_ld + (idn + _ni) : (idn + _ni)*c_ld + (idm + _mi);
+
+ // The final multiplication with alpha (in case beta == 0)
+ real result;
+ if (IsZero(beta)) {
+ Multiply(result, alpha, c_value);
+ }
+ // The final multiplication with alpha and the addition with beta*C
+ else {
+ AXPBY(result, alpha, c_value, beta, cgm[c_index + c_offset]);
}
+ cgm[c_index + c_offset] = result;
}
}
diff --git a/src/kernels/level3/xgemm_direct_part3.opencl b/src/kernels/level3/xgemm_direct_part3.opencl
index e1532e98..0822c95f 100644
--- a/src/kernels/level3/xgemm_direct_part3.opencl
+++ b/src/kernels/level3/xgemm_direct_part3.opencl
@@ -129,7 +129,14 @@ INLINE_FUNC void XgemmDirect(const int kSizeM, const int kSizeN, const int kSize
}
// Stores a tile of results and performs the multiplication with alpha and beta
- StoreResultsDirect(cgm, cpd, idm, idn, alpha, beta, c_ld, c_offset, c_transpose);
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ StoreResultsDirect(cgm, cpd[_ni * MWID + _mi], _mi, _ni, idm, idn,
+ alpha, beta, c_ld, c_offset, c_transpose);
+ }
+ }
}
// Simple but slower version for the parts on the edge (incomplete tiles in M and N-dimensions)
@@ -197,7 +204,14 @@ INLINE_FUNC void XgemmDirect(const int kSizeM, const int kSizeN, const int kSize
}
// Stores a tile of results and performs the multiplication with alpha and beta
- StoreResultsChecked(cgm, cpd, idm, idn, kSizeM, kSizeN, alpha, beta, c_ld, c_offset, c_transpose);
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ StoreResultsChecked(cgm, cpd[_ni * MWID + _mi], _mi, _ni, idm, idn, kSizeM, kSizeN,
+ alpha, beta, c_ld, c_offset, c_transpose);
+ }
+ }
}
}
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
index 1c7a940b..17c8955a 100644
--- a/src/kernels/level3/xgemm_part2.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -67,114 +67,108 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva
// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
-INLINE_FUNC void StoreResults(__global realM* cgm, realM cpm[NWI*MWI/VWM], const int kSizeM,
- const real alpha, const real beta) {
- #pragma unroll
- for (int _ni = 0; _ni < NWI; _ni += 1) {
- #pragma unroll
- for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
- #if STRM == 0
- int mg = _mi + get_local_id(0)*(MWI/VWM);
- #elif STRM == 1
- int mg = get_local_id(0) + _mi*MDIMC;
- #endif
- #if STRN == 0
- int ng = _ni + get_local_id(1)*NWI;
- #elif STRN == 1
- int ng = _ni%VWN + get_local_id(1)*VWN + (_ni/VWN)*VWN*NDIMC;
- #endif
- int idm = mg + GetGroupID0() * (MWG/VWM);
- int idn = ng + GetGroupID1() * NWG;
- int index = idn*(kSizeM/VWM) + idm;
+INLINE_FUNC void StoreResults(__global realM* cgm, realM c_value, const int _mi, const int _ni,
+ const int kSizeM, const real alpha, const real beta) {
+ #if STRM == 0
+ int mg = _mi + get_local_id(0)*(MWI/VWM);
+ #elif STRM == 1
+ int mg = get_local_id(0) + _mi*MDIMC;
+ #endif
+ #if STRN == 0
+ int ng = _ni + get_local_id(1)*NWI;
+ #elif STRN == 1
+ int ng = _ni%VWN + get_local_id(1)*VWN + (_ni/VWN)*VWN*NDIMC;
+ #endif
+ int idm = mg + GetGroupID0() * (MWG/VWM);
+ int idn = ng + GetGroupID1() * NWG;
+ int index = idn*(kSizeM/VWM) + idm;
- realM result;
- realM xval = cpm[_ni * (MWI/VWM) + _mi];
+ realM result;
+ realM xval = c_value;
- // The final multiplication with alpha (in case beta == 0)
- if (IsZero(beta)) {
- #if VWM == 1
- Multiply(result, alpha, xval);
- #elif VWM == 2
- Multiply(result.x, alpha, xval.x);
- Multiply(result.y, alpha, xval.y);
- #elif VWM == 4
- Multiply(result.x, alpha, xval.x);
- Multiply(result.y, alpha, xval.y);
- Multiply(result.z, alpha, xval.z);
- Multiply(result.w, alpha, xval.w);
- #elif VWM == 8
- Multiply(result.s0, alpha, xval.s0);
- Multiply(result.s1, alpha, xval.s1);
- Multiply(result.s2, alpha, xval.s2);
- Multiply(result.s3, alpha, xval.s3);
- Multiply(result.s4, alpha, xval.s4);
- Multiply(result.s5, alpha, xval.s5);
- Multiply(result.s6, alpha, xval.s6);
- Multiply(result.s7, alpha, xval.s7);
- #elif VWM == 16
- Multiply(result.s0, alpha, xval.s0);
- Multiply(result.s1, alpha, xval.s1);
- Multiply(result.s2, alpha, xval.s2);
- Multiply(result.s3, alpha, xval.s3);
- Multiply(result.s4, alpha, xval.s4);
- Multiply(result.s5, alpha, xval.s5);
- Multiply(result.s6, alpha, xval.s6);
- Multiply(result.s7, alpha, xval.s7);
- Multiply(result.s8, alpha, xval.s8);
- Multiply(result.s9, alpha, xval.s9);
- Multiply(result.sA, alpha, xval.sA);
- Multiply(result.sB, alpha, xval.sB);
- Multiply(result.sC, alpha, xval.sC);
- Multiply(result.sD, alpha, xval.sD);
- Multiply(result.sE, alpha, xval.sE);
- Multiply(result.sF, alpha, xval.sF);
- #endif
- }
+ // The final multiplication with alpha (in case beta == 0)
+ if (IsZero(beta)) {
+ #if VWM == 1
+ Multiply(result, alpha, xval);
+ #elif VWM == 2
+ Multiply(result.x, alpha, xval.x);
+ Multiply(result.y, alpha, xval.y);
+ #elif VWM == 4
+ Multiply(result.x, alpha, xval.x);
+ Multiply(result.y, alpha, xval.y);
+ Multiply(result.z, alpha, xval.z);
+ Multiply(result.w, alpha, xval.w);
+ #elif VWM == 8
+ Multiply(result.s0, alpha, xval.s0);
+ Multiply(result.s1, alpha, xval.s1);
+ Multiply(result.s2, alpha, xval.s2);
+ Multiply(result.s3, alpha, xval.s3);
+ Multiply(result.s4, alpha, xval.s4);
+ Multiply(result.s5, alpha, xval.s5);
+ Multiply(result.s6, alpha, xval.s6);
+ Multiply(result.s7, alpha, xval.s7);
+ #elif VWM == 16
+ Multiply(result.s0, alpha, xval.s0);
+ Multiply(result.s1, alpha, xval.s1);
+ Multiply(result.s2, alpha, xval.s2);
+ Multiply(result.s3, alpha, xval.s3);
+ Multiply(result.s4, alpha, xval.s4);
+ Multiply(result.s5, alpha, xval.s5);
+ Multiply(result.s6, alpha, xval.s6);
+ Multiply(result.s7, alpha, xval.s7);
+ Multiply(result.s8, alpha, xval.s8);
+ Multiply(result.s9, alpha, xval.s9);
+ Multiply(result.sA, alpha, xval.sA);
+ Multiply(result.sB, alpha, xval.sB);
+ Multiply(result.sC, alpha, xval.sC);
+ Multiply(result.sD, alpha, xval.sD);
+ Multiply(result.sE, alpha, xval.sE);
+ Multiply(result.sF, alpha, xval.sF);
+ #endif
+ }
- // The final multiplication with alpha and the addition with beta*C
- else {
- realM yval = cgm[index];
- #if VWM == 1
- AXPBY(result, alpha, xval, beta, yval);
- #elif VWM == 2
- AXPBY(result.x, alpha, xval.x, beta, yval.x);
- AXPBY(result.y, alpha, xval.y, beta, yval.y);
- #elif VWM == 4
- AXPBY(result.x, alpha, xval.x, beta, yval.x);
- AXPBY(result.y, alpha, xval.y, beta, yval.y);
- AXPBY(result.z, alpha, xval.z, beta, yval.z);
- AXPBY(result.w, alpha, xval.w, beta, yval.w);
- #elif VWM == 8
- AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
- AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
- AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
- AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
- AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
- AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
- AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
- AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
- #elif VWM == 16
- AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
- AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
- AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
- AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
- AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
- AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
- AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
- AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
- AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
- AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
- AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
- AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
- AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
- AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
- AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
- AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
- #endif
- }
- cgm[index] = result;
- }
+ // The final multiplication with alpha and the addition with beta*C
+ else {
+ realM yval = cgm[index];
+ #if VWM == 1
+ AXPBY(result, alpha, xval, beta, yval);
+ #elif VWM == 2
+ AXPBY(result.x, alpha, xval.x, beta, yval.x);
+ AXPBY(result.y, alpha, xval.y, beta, yval.y);
+ #elif VWM == 4
+ AXPBY(result.x, alpha, xval.x, beta, yval.x);
+ AXPBY(result.y, alpha, xval.y, beta, yval.y);
+ AXPBY(result.z, alpha, xval.z, beta, yval.z);
+ AXPBY(result.w, alpha, xval.w, beta, yval.w);
+ #elif VWM == 8
+ AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
+ AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
+ AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
+ AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
+ AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
+ AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
+ AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
+ AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
+ #elif VWM == 16
+ AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
+ AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
+ AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
+ AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
+ AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
+ AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
+ AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
+ AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
+ AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
+ AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
+ AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
+ AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
+ AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
+ AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
+ AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
+ AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
+ #endif
}
+ cgm[index] = result;
}
// =================================================================================================
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl
index 1483c26e..08778f0d 100644
--- a/src/kernels/level3/xgemm_part3.opencl
+++ b/src/kernels/level3/xgemm_part3.opencl
@@ -158,7 +158,13 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeM, alpha, beta);
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, kSizeM, alpha, beta);
+ }
+ }
}
// =================================================================================================