diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-12-09 14:09:13 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-12-09 14:09:13 +0100 |
commit | 23e3a85f2c328d4a23db2fca5d1d89d78163711f (patch) | |
tree | 02b8dd5364d958184c45c9bfdb2c28e38d72b24e | |
parent | d9df62b7942bb8af5fd385b8545aceb1d8b578f3 (diff) |
Reformatted GEMM kernel to support array-to-register promotion
-rw-r--r-- | src/kernel_preprocessor.cpp | 1 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_batched.opencl | 14 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part1.opencl | 182 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part2.opencl | 48 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_part3.opencl | 143 |
5 files changed, 183 insertions, 205 deletions
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp index 8738a837..46b6f3df 100644 --- a/src/kernel_preprocessor.cpp +++ b/src/kernel_preprocessor.cpp @@ -556,6 +556,7 @@ std::string PreprocessKernelSource(const std::string& kernel_source) { auto arrays_to_registers = std::unordered_map<std::string, size_t>(); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false); + lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, true); // Gather the results diff --git a/src/kernels/level3/xgemm_batched.opencl b/src/kernels/level3/xgemm_batched.opencl index c7bf10d5..372f910b 100644 --- a/src/kernels/level3/xgemm_batched.opencl +++ b/src/kernels/level3/xgemm_batched.opencl @@ -46,20 +46,16 @@ void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK, __local realN blm[KWG * NWG/VWN]; #endif - // Computes the matrix-multiplication and stores the result in register memory - realM cpm[NWI][MWI/VWM]; + // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, alm, blm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm, blm); #elif SA == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, alm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm); #elif SB == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, blm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, blm); #else - XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta); #endif - - // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta - StoreResults(cgm_, cpm, kSizeM, alpha, beta); } // ================================================================================================= diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl index 88744668..053eb721 100644 --- a/src/kernels/level3/xgemm_part1.opencl +++ b/src/kernels/level3/xgemm_part1.opencl @@ -135,50 +135,46 @@ R"( // ================================================================================================= // Initializes the accumulation registers to zero -INLINE_FUNC void InitAccRegisters(realM cpm[NWI*MWI/VWM]) { - #pragma unroll - for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - #pragma unroll - for (int _ni = 0; _ni < NWI; _ni += 1) { - #if VWM == 1 - SetToZero(cpm[_ni * (MWI/VWM) + _mi]); - #elif VWM == 2 - SetToZero(cpm[_ni * (MWI/VWM) + _mi].x); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].y); - #elif VWM == 4 - SetToZero(cpm[_ni * (MWI/VWM) + _mi].x); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].y); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].z); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].w); - #elif VWM == 8 - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s0); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s1); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s2); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s3); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s4); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s5); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s6); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s7); - #elif VWM == 16 - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s0); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s1); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s2); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s3); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s4); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s5); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s6); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s7); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s8); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].s9); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].sA); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].sB); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].sC); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].sD); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].sE); - SetToZero(cpm[_ni * (MWI/VWM) + _mi].sF); - #endif - } - } +INLINE_FUNC realM InitAccRegisters() { + realM result; + #if VWM == 1 + SetToZero(result); + #elif VWM == 2 + SetToZero(result.x); + SetToZero(result.y); + #elif VWM == 4 + SetToZero(result.x); + SetToZero(result.y); + SetToZero(result.z); + SetToZero(result.w); + #elif VWM == 8 + SetToZero(result.s0); + SetToZero(result.s1); + SetToZero(result.s2); + SetToZero(result.s3); + SetToZero(result.s4); + SetToZero(result.s5); + SetToZero(result.s6); + SetToZero(result.s7); + #elif VWM == 16 + SetToZero(result.s0); + SetToZero(result.s1); + SetToZero(result.s2); + SetToZero(result.s3); + SetToZero(result.s4); + SetToZero(result.s5); + SetToZero(result.s6); + SetToZero(result.s7); + SetToZero(result.s8); + SetToZero(result.s9); + SetToZero(result.sA); + SetToZero(result.sB); + SetToZero(result.sC); + SetToZero(result.sD); + SetToZero(result.sE); + SetToZero(result.sF); + #endif + return result; } // ================================================================================================= @@ -249,47 +245,39 @@ INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm, LOCAL_PTR re // Caches global off-chip memory directly into per-thread private memory (registers). This function // is specific for caching the A input matrix. #if SA == 0 -INLINE_FUNC void GlobalToPrivateA(const __global realM* restrict agm, realM apm[MWI/VWM], - const int kSizeM, const int idk, const int kwg) { - #pragma unroll - for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - - // Computes the indices based on strided/non-strided access - #if STRM == 0 - int mg = _mi + get_local_id(0)*(MWI/VWM); - #elif STRM == 1 - int mg = get_local_id(0) + _mi*MDIMC; - #endif - - // Computes the indices for the global memory - int idm = mg + GetGroupID0() * (MWG/VWM); - - // Loads the data from global memory (not transposed) and stores into registers - apm[_mi] = agm[idk*(kSizeM/VWM) + idm]; - } +INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int _mi, + const int kSizeM, const int idk, const int kwg) { + // Computes the indices based on strided/non-strided access + #if STRM == 0 + int mg = _mi + get_local_id(0)*(MWI/VWM); + #elif STRM == 1 + int mg = get_local_id(0) + _mi*MDIMC; + #endif + + // Computes the indices for the global memory + int idm = mg + GetGroupID0() * (MWG/VWM); + + // Loads the data from global memory (not transposed) and stores into registers + return agm[idk*(kSizeM/VWM) + idm]; } #endif // Same as above, but now for the B input matrix #if SB == 0 -INLINE_FUNC void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[NWI/VWN], - const int kSizeN, const int idk) { - #pragma unroll - for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - - // Computes the indices based on strided/non-strided access - #if STRN == 0 - int ng = _ni + get_local_id(1)*(NWI/VWN); - #elif STRN == 1 - int ng = get_local_id(1) + _ni*NDIMC; - #endif - - // Computes the indices for the global memory - int idn = ng + GetGroupID1() * (NWG/VWN); - - // Loads the data from global memory (transposed) and stores into registers - bpm[_ni] = bgm[idk*(kSizeN/VWN) + idn]; - } +INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int _ni, + const int kSizeN, const int idk) { + // Computes the indices based on strided/non-strided access + #if STRN == 0 + int ng = _ni + get_local_id(1)*(NWI/VWN); + #elif STRN == 1 + int ng = get_local_id(1) + _ni*NDIMC; + #endif + + // Computes the indices for the global memory + int idn = ng + GetGroupID1() * (NWG/VWN); + + // Loads the data from global memory (transposed) and stores into registers + return bgm[idk*(kSizeN/VWN) + idn]; } #endif @@ -298,31 +286,25 @@ INLINE_FUNC void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[ // Caches on-chip local memory into per-thread private memory (registers). This function is specific // for caching the A input matrix. #if SA == 1 -INLINE_FUNC void LocalToPrivateA(LOCAL_PTR realM* alm, realM apm[MWI/VWM], const int kg) { - #pragma unroll - for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - #if STRM == 0 - int mg = _mi + get_local_id(0)*(MWI/VWM); - #elif STRM == 1 - int mg = get_local_id(0) + _mi*MDIMC; - #endif - apm[_mi] = alm[kg*(MWG/VWM) + mg]; - } +INLINE_FUNC realM LocalToPrivateA(LOCAL_PTR realM* alm, const int _mi, const int kg) { + #if STRM == 0 + int mg = _mi + get_local_id(0)*(MWI/VWM); + #elif STRM == 1 + int mg = get_local_id(0) + _mi*MDIMC; + #endif + return alm[kg*(MWG/VWM) + mg]; } #endif // Same as above, but now for the B input matrix #if SB == 1 -INLINE_FUNC void LocalToPrivateB(LOCAL_PTR realN* blm, realN bpm[NWI/VWN], const int kg) { - #pragma unroll - for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - #if STRN == 0 - int ng = _ni + get_local_id(1)*(NWI/VWN); - #elif STRN == 1 - int ng = get_local_id(1) + _ni*NDIMC; - #endif - bpm[_ni] = blm[kg*(NWG/VWN) + ng]; - } +INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm, const int _ni, const int kg) { + #if STRN == 0 + int ng = _ni + get_local_id(1)*(NWI/VWN); + #elif STRN == 1 + int ng = get_local_id(1) + _ni*NDIMC; + #endif + return blm[kg*(NWG/VWN) + ng]; } #endif diff --git a/src/kernels/level3/xgemm_part2.opencl b/src/kernels/level3/xgemm_part2.opencl index 88100e96..14a0493a 100644 --- a/src/kernels/level3/xgemm_part2.opencl +++ b/src/kernels/level3/xgemm_part2.opencl @@ -63,54 +63,6 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva return cvec; } -// Performs the actual computation: Cpm += Apm * Bpm -INLINE_FUNC void MultiplyAccumulate(realM cpm[NWI*MWI/VWM], realM apm[MWI/VWM], realN bpm[NWI/VWN]) { - #pragma unroll - for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - #pragma unroll - for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - const realM aval = apm[_mi]; - #if VWN == 1 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); - #elif VWN == 2 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); - #elif VWN == 4 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); - cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); - cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); - #elif VWN == 8 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); - cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); - cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); - cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); - cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); - cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); - cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); - #elif VWN == 16 - cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); - cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); - cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); - cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); - cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); - cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); - cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); - cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); - cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); - cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); - cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); - cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); - cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); - cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); - cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); - cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); - #endif - } - } -} - // ================================================================================================= // Merges the results in Cpm with the global array in Cgm. This also performs the multiplication diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl index f12fb304..157c1326 100644 --- a/src/kernels/level3/xgemm_part3.opencl +++ b/src/kernels/level3/xgemm_part3.opencl @@ -20,7 +20,7 @@ R"( // Main body of the matrix-multiplication algorithm. It calls various (inlined) functions. INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, const __global realM* restrict agm, const __global realN* restrict bgm, - __global realM* cgm, realM cpm[NWI*MWI/VWM] + __global realM* cgm, const real alpha, const real beta #if SA == 1 && SB == 1 , LOCAL_PTR realM* alm, LOCAL_PTR realN* blm #elif SA == 1 @@ -31,10 +31,12 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, ) { // Allocates workitem-private memory (registers) - //#pragma promote_to_registers + #pragma promote_to_registers realM apm[MWI/VWM]; - //#pragma promote_to_registers + #pragma promote_to_registers realN bpm[NWI/VWN]; + #pragma promote_to_registers + realM cpm[NWI*(MWI/VWM)]; // Combined thread identifier (volatile to disable caching) #if SA == 1 || SB == 1 @@ -42,7 +44,14 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #endif // Initializes the accumulation registers - InitAccRegisters(cpm); + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + #pragma unroll + for (int _ni = 0; _ni < NWI; _ni += 1) { + cpm[_ni * (MWI/VWM) + _mi] = InitAccRegisters(); + } + } + // Loops over all workgroup tiles for (int kwg = 0; kwg < kSizeK; kwg += KWG) { @@ -70,24 +79,74 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, int kg = pwi + _pit; #endif - // Loads data: local --> private (matrix A) - #if SA == 1 - LocalToPrivateA(alm, apm, kg); - // Loads data: off-chip --> private (matrix A) - #else - GlobalToPrivateA(agm, apm, kSizeM, idk, kwg); - #endif + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + // Loads data: local --> private (matrix A) + #if SA == 1 + apm[_mi] = LocalToPrivateA(alm, _mi, kg); + // Loads data: off-chip --> private (matrix A) + #else + apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg); + #endif + } // Loads data: local --> private (matrix B) - #if SB == 1 - LocalToPrivateB(blm, bpm, kg); - // Loads data: off-chip --> private (matrix B) - #else - GlobalToPrivateB(bgm, bpm, kSizeN, idk); - #endif + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #if SB == 1 + bpm[_ni] = LocalToPrivateB(blm, _ni, kg); + // Loads data: off-chip --> private (matrix B) + #else + bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); + #endif + } // Performs the accumulation (Cpm += Apm * Bpm) - MultiplyAccumulate(cpm, apm, bpm); + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + const realM aval = apm[_mi]; + #if VWN == 1 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); + #elif VWN == 2 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + #elif VWN == 4 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); + #elif VWN == 8 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + #elif VWN == 16 + cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); + cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); + cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); + cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); + cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); + cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); + cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); + cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); + #endif + } + } + } } #if SA == 1 || SB == 1 @@ -97,6 +156,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #if GLOBAL_MEM_FENCE == 1 barrier(CLK_GLOBAL_MEM_FENCE); #endif + + // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta + StoreResults(cgm, cpm, kSizeM, alpha, beta); } // ================================================================================================= @@ -127,21 +189,16 @@ void XgemmUpper(const int kSizeN, const int kSizeK, __local realN blm[KWG * NWG/VWN]; #endif - // Computes the matrix-multiplication and stores the result in register memory - //#pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; + // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm); #elif SA == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm); #elif SB == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm); #else - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta); #endif - - // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta - StoreResults(cgm, cpm, kSizeN, alpha, beta); } // Main entry point of the kernel. This is the lower-triangular version. @@ -168,21 +225,16 @@ void XgemmLower(const int kSizeN, const int kSizeK, __local realN blm[KWG * NWG/VWN]; #endif - // Computes the matrix-multiplication and stores the result in register memory - //#pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; + // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm); #elif SA == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm); #elif SB == 1 - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm); #else - XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, cpm); + XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta); #endif - - // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta - StoreResults(cgm, cpm, kSizeN, alpha, beta); } // ================================================================================================= @@ -213,21 +265,16 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, __local realN blm[KWG * NWG/VWN]; #endif - // Computes the matrix-multiplication and stores the result in register memory - //#pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; + // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm, blm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm, blm); #elif SA == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, alm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, alm); #elif SB == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm, blm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta, blm); #else - XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, cpm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, alpha, beta); #endif - - // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta - StoreResults(cgm, cpm, kSizeM, alpha, beta); } #endif |