author     Cedric Nugteren <web@cedricnugteren.nl>   2017-12-09 14:09:13 +0100
committer  Cedric Nugteren <web@cedricnugteren.nl>   2017-12-09 14:09:13 +0100
commit     23e3a85f2c328d4a23db2fca5d1d89d78163711f (patch)
tree       02b8dd5364d958184c45c9bfdb2c28e38d72b24e
parent     d9df62b7942bb8af5fd385b8545aceb1d8b578f3 (diff)
Reformatted GEMM kernel to support array-to-register promotion
-rw-r--r--  src/kernel_preprocessor.cpp                1
-rw-r--r--  src/kernels/level3/xgemm_batched.opencl   14
-rw-r--r--  src/kernels/level3/xgemm_part1.opencl    182
-rw-r--r--  src/kernels/level3/xgemm_part2.opencl     48
-rw-r--r--  src/kernels/level3/xgemm_part3.opencl    143
5 files changed, 183 insertions, 205 deletions
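
Background for the reformatting below: the kernel preprocessor
(src/kernel_preprocessor.cpp) can replace a small per-thread array with plain
registers, but only if the array is declared locally, marked with
#pragma promote_to_registers, and accessed exclusively through
compile-time-constant indices inside fully unrolled loops. This is presumably
why the accumulator cpm moves inside XgemmBody (which now receives alpha and
beta and calls StoreResults itself) instead of being passed as a parameter:
an array that crosses a function boundary decays to a pointer and cannot be
split into registers. A minimal sketch of the transformation, assuming
MWI/VWM == 2; the generated register names are illustrative, not the
preprocessor's exact output:

// Before preprocessing: private array plus a fully unrolled loop
#pragma promote_to_registers
realM apm[MWI/VWM];
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
  apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg);
}

// After preprocessing: the loop is unrolled and each element becomes a
// standalone register, so no indexed private array remains
realM apm_0;
realM apm_1;
apm_0 = GlobalToPrivateA(agm, 0, kSizeM, idk, kwg);
apm_1 = GlobalToPrivateA(agm, 1, kSizeM, idk, kwg);
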
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp
index 8738a837..46b6f3df 100644
--- a/src/kernel_preprocessor.cpp
+++ b/src/kernel_preprocessor.cpp
@@ -556,6 +556,7 @@ std::string PreprocessKernelSource(const std::string& kernel_source) {
auto arrays_to_registers = std::unordered_map<std::string, size_t>();
lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers);
lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false);
+ lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false);
lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, true);
// Gather the results
diff --git a/src/kernels/level3/xgemm_batched.opencl b/src/kernels/level3/xgemm_batched.opencl
index c7bf10d5..372f910b 100644
--- a/src/kernels/level3/xgemm_batched.opencl
+++ b/src/kernels/level3/xgemm_batched.opencl
@@ -46,20 +46,16 @@ void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
__local realN blm[KWG * NWG/VWN];
#endif
- // Computes the matrix-multiplication and stores the result in register memory
- realM cpm[NWI][MWI/VWM];
+ // Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, alm, blm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm, blm);
#elif SA == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, alm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm);
#elif SB == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, blm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, blm);
#else
- XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta);
#endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm_, cpm, kSizeM, alpha, beta);
}
// =================================================================================================
diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl
index 88744668..053eb721 100644
--- a/src/kernels/level3/xgemm_part1.opencl
+++ b/src/kernels/level3/xgemm_part1.opencl
@@ -135,50 +135,46 @@ R"(
// =================================================================================================
// Initializes the accumulation registers to zero
-INLINE_FUNC void InitAccRegisters(realM cpm[NWI*MWI/VWM]) {
- #pragma unroll
- for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
- #pragma unroll
- for (int _ni = 0; _ni < NWI; _ni += 1) {
- #if VWM == 1
- SetToZero(cpm[_ni * (MWI/VWM) + _mi]);
- #elif VWM == 2
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].x);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].y);
- #elif VWM == 4
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].x);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].y);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].z);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].w);
- #elif VWM == 8
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s0);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s1);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s2);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s3);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s4);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s5);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s6);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s7);
- #elif VWM == 16
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s0);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s1);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s2);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s3);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s4);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s5);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s6);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s7);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s8);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].s9);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].sA);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].sB);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].sC);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].sD);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].sE);
- SetToZero(cpm[_ni * (MWI/VWM) + _mi].sF);
- #endif
- }
- }
+INLINE_FUNC realM InitAccRegisters() {
+ realM result;
+ #if VWM == 1
+ SetToZero(result);
+ #elif VWM == 2
+ SetToZero(result.x);
+ SetToZero(result.y);
+ #elif VWM == 4
+ SetToZero(result.x);
+ SetToZero(result.y);
+ SetToZero(result.z);
+ SetToZero(result.w);
+ #elif VWM == 8
+ SetToZero(result.s0);
+ SetToZero(result.s1);
+ SetToZero(result.s2);
+ SetToZero(result.s3);
+ SetToZero(result.s4);
+ SetToZero(result.s5);
+ SetToZero(result.s6);
+ SetToZero(result.s7);
+ #elif VWM == 16
+ SetToZero(result.s0);
+ SetToZero(result.s1);
+ SetToZero(result.s2);
+ SetToZero(result.s3);
+ SetToZero(result.s4);
+ SetToZero(result.s5);
+ SetToZero(result.s6);
+ SetToZero(result.s7);
+ SetToZero(result.s8);
+ SetToZero(result.s9);
+ SetToZero(result.sA);
+ SetToZero(result.sB);
+ SetToZero(result.sC);
+ SetToZero(result.sD);
+ SetToZero(result.sE);
+ SetToZero(result.sF);
+ #endif
+ return result;
}
// =================================================================================================
@@ -249,47 +245,39 @@ INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm, LOCAL_PTR re
// Caches global off-chip memory directly into per-thread private memory (registers). This function
// is specific for caching the A input matrix.
#if SA == 0
-INLINE_FUNC void GlobalToPrivateA(const __global realM* restrict agm, realM apm[MWI/VWM],
- const int kSizeM, const int idk, const int kwg) {
- #pragma unroll
- for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
-
- // Computes the indices based on strided/non-strided access
- #if STRM == 0
- int mg = _mi + get_local_id(0)*(MWI/VWM);
- #elif STRM == 1
- int mg = get_local_id(0) + _mi*MDIMC;
- #endif
-
- // Computes the indices for the global memory
- int idm = mg + GetGroupID0() * (MWG/VWM);
-
- // Loads the data from global memory (not transposed) and stores into registers
- apm[_mi] = agm[idk*(kSizeM/VWM) + idm];
- }
+INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int _mi,
+ const int kSizeM, const int idk, const int kwg) {
+ // Computes the indices based on strided/non-strided access
+ #if STRM == 0
+ int mg = _mi + get_local_id(0)*(MWI/VWM);
+ #elif STRM == 1
+ int mg = get_local_id(0) + _mi*MDIMC;
+ #endif
+
+ // Computes the indices for the global memory
+ int idm = mg + GetGroupID0() * (MWG/VWM);
+
+ // Loads the data from global memory (not transposed) and stores into registers
+ return agm[idk*(kSizeM/VWM) + idm];
}
#endif
// Same as above, but now for the B input matrix
#if SB == 0
-INLINE_FUNC void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[NWI/VWN],
- const int kSizeN, const int idk) {
- #pragma unroll
- for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
-
- // Computes the indices based on strided/non-strided access
- #if STRN == 0
- int ng = _ni + get_local_id(1)*(NWI/VWN);
- #elif STRN == 1
- int ng = get_local_id(1) + _ni*NDIMC;
- #endif
-
- // Computes the indices for the global memory
- int idn = ng + GetGroupID1() * (NWG/VWN);
-
- // Loads the data from global memory (transposed) and stores into registers
- bpm[_ni] = bgm[idk*(kSizeN/VWN) + idn];
- }
+INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int _ni,
+ const int kSizeN, const int idk) {
+ // Computes the indices based on strided/non-strided access
+ #if STRN == 0
+ int ng = _ni + get_local_id(1)*(NWI/VWN);
+ #elif STRN == 1
+ int ng = get_local_id(1) + _ni*NDIMC;
+ #endif
+
+ // Computes the indices for the global memory
+ int idn = ng + GetGroupID1() * (NWG/VWN);
+
+ // Loads the data from global memory (transposed) and stores into registers
+ return bgm[idk*(kSizeN/VWN) + idn];
}
#endif
@@ -298,31 +286,25 @@ INLINE_FUNC void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[
// Caches on-chip local memory into per-thread private memory (registers). This function is specific
// for caching the A input matrix.
#if SA == 1
-INLINE_FUNC void LocalToPrivateA(LOCAL_PTR realM* alm, realM apm[MWI/VWM], const int kg) {
- #pragma unroll
- for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
- #if STRM == 0
- int mg = _mi + get_local_id(0)*(MWI/VWM);
- #elif STRM == 1
- int mg = get_local_id(0) + _mi*MDIMC;
- #endif
- apm[_mi] = alm[kg*(MWG/VWM) + mg];
- }
+INLINE_FUNC realM LocalToPrivateA(LOCAL_PTR realM* alm, const int _mi, const int kg) {
+ #if STRM == 0
+ int mg = _mi + get_local_id(0)*(MWI/VWM);
+ #elif STRM == 1
+ int mg = get_local_id(0) + _mi*MDIMC;
+ #endif
+ return alm[kg*(MWG/VWM) + mg];
}
#endif
// Same as above, but now for the B input matrix
#if SB == 1
-INLINE_FUNC void LocalToPrivateB(LOCAL_PTR realN* blm, realN bpm[NWI/VWN], const int kg) {
- #pragma unroll
- for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
- #if STRN == 0
- int ng = _ni + get_local_id(1)*(NWI/VWN);
- #elif STRN == 1
- int ng = get_local_id(1) + _ni*NDIMC;
- #endif
- bpm[_ni] = blm[kg*(NWG/VWN) + ng];
- }
+INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm, const int _ni, const int kg) {
+ #if STRN == 0
+ int ng = _ni + get_local_id(1)*(NWI/VWN);
+ #elif STRN == 1
+ int ng = get_local_id(1) + _ni*NDIMC;
+ #endif
+ return blm[kg*(NWG/VWN) + ng];
}
#endif
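
The refactored loaders above keep the STRM/STRN index arithmetic unchanged;
only the per-element loop moves to the caller. As a quick check of the two
layouts, assume MWI == 8 and VWM == 4 (so MWI/VWM == 2 vectors per thread)
and MDIMC == 4 threads; these tuning values are illustrative:

// STRM == 0 (contiguous): mg = _mi + get_local_id(0)*(MWI/VWM)
//   thread 0 -> mg = 0, 1      thread 1 -> mg = 2, 3
// STRM == 1 (strided):    mg = get_local_id(0) + _mi*MDIMC
//   thread 0 -> mg = 0, 4      thread 1 -> mg = 1, 5
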
diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl
index 88100e96..14a0493a 100644
--- a/src/kernels/level3/xgemm_part2.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -63,54 +63,6 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva
return cvec;
}
-// Performs the actual computation: Cpm += Apm * Bpm
-INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI*MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) {
- #pragma unroll
- for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
- #pragma unroll
- for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
- const realM aval = apm[_mi];
- #if VWN == 1
- cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
- #elif VWN == 2
- cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
- cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
- #elif VWN == 4
- cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
- cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
- cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
- cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
- #elif VWN == 8
- cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
- cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
- cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
- cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
- cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
- cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
- cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
- cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
- #elif VWN == 16
- cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
- cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
- cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
- cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
- cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
- cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
- cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
- cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
- cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
- cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
- cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
- cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
- cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
- cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
- cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
- cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
- #endif
- }
- }
-}
-
// =================================================================================================
// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
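
The MultiplyAccumulate helper removed above is inlined into XgemmBody in
xgemm_part3.opencl below, so that the preprocessor sees the accumulator
accesses as unrolled constant indices. The indexing flattens the logical
NWI x (MWI/VWM) register tile row-major, with row n = _ni*VWN + k for vector
lane k of bpm[_ni]. A worked example under assumed tuning parameters
(NWI == 4, VWN == 2, MWI == 8, VWM == 4, so cpm holds NWI*(MWI/VWM) == 8
vectors):

// lane k = 1 (the .y component of bpm[1]), _ni = 1, _mi = 0:
//   cpm[(_ni*VWN + k)*(MWI/VWM) + _mi] = cpm[(1*2 + 1)*2 + 0] = cpm[6]
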
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl
index f12fb304..157c1326 100644
--- a/src/kernels/level3/xgemm_part3.opencl
+++ b/src/kernels/level3/xgemm_part3.opencl
@@ -20,7 +20,7 @@ R"(
// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.
INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
const __global realM* restrict agm, const __global realN* restrict bgm,
- __global realM* cgm, realM cpm[NWI*MWI/VWM]
+ __global realM* cgm, const real alpha, const real beta
#if SA == 1 && SB == 1
, LOCAL_PTR realM* alm, LOCAL_PTR realN* blm
#elif SA == 1
@@ -31,10 +31,12 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
) {
// Allocates workitem-private memory (registers)
- //#pragma promote_to_registers
+ #pragma promote_to_registers
realM apm[MWI/VWM];
- //#pragma promote_to_registers
+ #pragma promote_to_registers
realN bpm[NWI/VWN];
+ #pragma promote_to_registers
+ realM cpm[NWI*(MWI/VWM)];
// Combined thread identifier (volatile to disable caching)
#if SA == 1 || SB == 1
@@ -42,7 +44,14 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// Initializes the accumulation registers
- InitAccRegisters(cpm);
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI; _ni += 1) {
+ cpm[_ni * (MWI/VWM) + _mi] = InitAccRegisters();
+ }
+ }
+
// Loops over all workgroup tiles
for (int kwg = 0; kwg < kSizeK; kwg += KWG) {
@@ -70,24 +79,74 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
int kg = pwi + _pit;
#endif
- // Loads data: local --> private (matrix A)
- #if SA == 1
- LocalToPrivateA(alm, apm, kg);
- // Loads data: off-chip --> private (matrix A)
- #else
- GlobalToPrivateA(agm, apm, kSizeM, idk, kwg);
- #endif
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ // Loads data: local --> private (matrix A)
+ #if SA == 1
+ apm[_mi] = LocalToPrivateA(alm, _mi, kg);
+ // Loads data: off-chip --> private (matrix A)
+ #else
+ apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg);
+ #endif
+ }
// Loads data: local --> private (matrix B)
- #if SB == 1
- LocalToPrivateB(blm, bpm, kg);
- // Loads data: off-chip --> private (matrix B)
- #else
- GlobalToPrivateB(bgm, bpm, kSizeN, idk);
- #endif
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
+ #if SB == 1
+ bpm[_ni] = LocalToPrivateB(blm, _ni, kg);
+ // Loads data: off-chip --> private (matrix B)
+ #else
+ bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk);
+ #endif
+ }
// Performs the accumulation (Cpm += Apm * Bpm)
- MultiplyAccumulate(cpm, apm, bpm);
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ const realM aval = apm[_mi];
+ #if VWN == 1
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
+ #elif VWN == 2
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
+ #elif VWN == 4
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
+ #elif VWN == 8
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
+ #elif VWN == 16
+ cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
+ cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
+ cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
+ cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
+ cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
+ cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
+ cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
+ cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
+ cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
+ #endif
+ }
+ }
+
}
}
#if SA == 1 || SB == 1
@@ -97,6 +156,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#if GLOBAL_MEM_FENCE == 1
barrier(CLK_GLOBAL_MEM_FENCE);
#endif
+
+ // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
+ StoreResults(cgm, cpm, kSizeM, alpha, beta);
}
// =================================================================================================
@@ -127,21 +189,16 @@ void XgemmUpper(const int kSizeN, const int kSizeK,
__local realN blm[KWG * NWG/VWN];
#endif
- // Computes the matrix-multiplication and stores the result in register memory
- //#pragma promote_to_registers
- realM cpm[NWI*(MWI/VWM)];
+ // Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm);
#elif SA == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm);
#elif SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm);
#else
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta);
#endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeN, alpha, beta);
}
// Main entry point of the kernel. This is the lower-triangular version.
@@ -168,21 +225,16 @@ void XgemmLower(const int kSizeN, const int kSizeK,
__local realN blm[KWG * NWG/VWN];
#endif
- // Computes the matrix-multiplication and stores the result in register memory
- //#pragma promote_to_registers
- realM cpm[NWI*(MWI/VWM)];
+ // Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm);
#elif SA == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm);
#elif SB == 1
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm);
#else
- XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta);
#endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeN, alpha, beta);
}
// =================================================================================================
@@ -213,21 +265,16 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
__local realN blm[KWG * NWG/VWN];
#endif
- // Computes the matrix-multiplication and stores the result in register memory
- //#pragma promote_to_registers
- realM cpm[NWI*(MWI/VWM)];
+ // Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm);
#elif SA == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm);
#elif SB == 1
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm);
#else
- XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm);
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta);
#endif
-
- // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
- StoreResults(cgm, cpm, kSizeM, alpha, beta);
}
#endif