summary | refs | log | tree | commit | diff
path: root/src/kernels/level3/xgemm_part3.opencl
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2017-12-09 14:09:13 +0100
committer: Cedric Nugteren <web@cedricnugteren.nl> 2017-12-09 14:09:13 +0100
commit: 23e3a85f2c328d4a23db2fca5d1d89d78163711f (patch)
tree: 02b8dd5364d958184c45c9bfdb2c28e38d72b24e /src/kernels/level3/xgemm_part3.opencl
parent: d9df62b7942bb8af5fd385b8545aceb1d8b578f3 (diff)
Reformatted GEMM kernel to support array-to-register promotion
Diffstat (limited to 'src/kernels/level3/xgemm_part3.opencl')
-rw-r--r--  src/kernels/level3/xgemm_part3.opencl  143
1 file changed, 95 insertions(+), 48 deletions(-)
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl
index f12fb304..157c1326 100644
--- a/src/kernels/level3/xgemm_part3.opencl
+++ b/src/kernels/level3/xgemm_part3.opencl
@@ -20,7 +20,7 @@ R"(
// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.
INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
const __global realM* restrict agm, const __global realN* restrict bgm,
- __global realM* cgm, realM cpm[NWI*MWI/VWM]
+ __global realM* cgm, const real alpha, const real beta
#if SA == 1 && SB == 1
, LOCAL_PTR realM* alm, LOCAL_PTR realN* blm
#elif SA == 1
@@ -31,10 +31,12 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
) {
// Allocates workitem-private memory (registers)
- //#pragma promote_to_registers
+ #pragma promote_to_registers
realM apm[MWI/VWM];
- //#pragma promote_to_registers
+ #pragma promote_to_registers
realN bpm[NWI/VWN];
+ #pragma promote_to_registers
+ realM cpm[NWI*(MWI/VWM)];
// Combined thread identifier (volatile to disable caching)
#if SA == 1 || SB == 1
@@ -42,7 +44,14 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// Initializes the accumulation registers
- InitAccRegisters(cpm);
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI; _ni += 1) {
+ cpm[_ni * (MWI/VWM) + _mi] = InitAccRegisters();
+ }
+ }
+
// Loops over all workgroup tiles
for (int kwg = 0; kwg < kSizeK; kwg += KWG) {
@@ -70,24 +79,74 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
int kg = pwi + _pit;
#endif
- // Loads data: local --> private (matrix A)
- #if SA == 1
- LocalToPrivateA(alm, apm, kg);
- // Loads data: off-chip --> private (matrix A)
- #else
- GlobalToPrivateA(agm, apm, kSizeM, idk, kwg);
- #endif
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ // Loads data: local --> private (matrix A)
+ #if SA == 1
+ apm[_mi] = LocalToPrivateA(alm, _mi, kg);
+ // Loads data: off-chip --> private (matrix A)
+ #else
+ apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg);
+ #endif
+ }
// Loads data: local --> private (matrix B)
- #if SB == 1
- LocalToPrivateB(blm, bpm, kg);
- // Loads data: off-chip --> private (matrix B)
- #else
- GlobalToPrivateB(bgm, bpm, kSizeN, idk);
- #endif
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
+ #if SB == 1
+ bpm[_ni] = LocalToPrivateB(blm, _ni, kg);
+ // Loads data: off-chip --> private (matrix B)
+ #else
+ bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk);
+ #endif
+ }
// Performs the accumulation (Cpm += Apm * Bpm)
- MultiplyAccumulate(cpm, apm, bpm);
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ const realM aval = apm[_mi];
+ #if VWN == 1
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
+ #elif VWN == 2
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
+ #elif VWN == 4
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
+ #elif VWN == 8
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
+ #elif VWN == 16
+ cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
+ cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
+ cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
+ cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
+ cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
+ cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
+ cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
+ cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
+ cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
+ #endif
+ }
+ }
+
}
}
#if SA == 1 || SB == 1
@@ -97,6 +156,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#if GLOBAL_MEM_FENCE == 1
barrier(CLK_GLOBAL_MEM_FENCE);
#endif
+
+ // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
+ StoreResults(cgm, cpm, kSizeM, alpha, beta);
}
// =================================================================================================
@@ -127,21 +189,16 @@ void XgemmUpper(const int kSizeN, const int kSizeK,
__local realN blm[KWG * NWG/VWN];
#endif
- // Computes the matrix-multiplication and stores the result in register memory
- //#pragma promote_to_registers
- realM cpm[NWI*(MWI/VWM)];
+ // Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm);
#elif SA == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm);
#elif SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm);
#else
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta);
#endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeN, alpha, beta);
}
// Main entry point of the kernel. This is the lower-triangular version.
@@ -168,21 +225,16 @@ void XgemmLower(const int kSizeN, const int kSizeK,
__local realN blm[KWG * NWG/VWN];
#endif
- // Computes the matrix-multiplication and stores the result in register memory
- //#pragma promote_to_registers
- realM cpm[NWI*(MWI/VWM)];
+ // Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm);
#elif SA == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm);
#elif SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm);
#else
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta);
#endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeN, alpha, beta);
}
// =================================================================================================
@@ -213,21 +265,16 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
__local realN blm[KWG * NWG/VWN];
#endif
- // Computes the matrix-multiplication and stores the result in register memory
- //#pragma promote_to_registers
- realM cpm[NWI*(MWI/VWM)];
+ // Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm);
#elif SA == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm);
#elif SB == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm);
#else
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta);
#endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeM, alpha, beta);
}
#endif