diff options
Diffstat (limited to 'src/kernels/level3/xgemm_part3.opencl')
-rw-r--r-- | src/kernels/level3/xgemm_part3.opencl | 216 |
1 files changed, 152 insertions, 64 deletions
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl index 08778f0d..d7ddeb15 100644 --- a/src/kernels/level3/xgemm_part3.opencl +++ b/src/kernels/level3/xgemm_part3.opencl @@ -31,12 +31,26 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, ) { // Allocates workitem-private memory (registers) + #if GEMMK == 0 + #pragma promote_to_registers + realM apm[MWI/VWM]; // MWI * 1 + #pragma promote_to_registers + realN bpm[NWI/VWN]; // 1 * NWI + #elif GEMMK == 1 + #pragma promote_to_registers + realN apm[NWI*(KREG/VWN)]; // NWI * KREG + #pragma promote_to_registers + realM bpm[KREG*(MWI/VWM)]; // KREG * MWI + #endif #pragma promote_to_registers - realM apm[MWI/VWM]; - #pragma promote_to_registers - realN bpm[NWI/VWN]; - #pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; + realM cpm[NWI*(MWI/VWM)]; // NWI * MWI + + #if GEMMK == 1 + const __global real* restrict a_ptr = (const __global real* restrict) &agm[0]; + const __global real* restrict b_ptr = (const __global real* restrict) &bgm[0]; + const int tid_x = get_global_id(0); + const int tid_y = get_global_id(1); + #endif // Combined thread identifier (volatile to disable caching) #if SA == 1 || SB == 1 @@ -52,9 +66,8 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, } } - // Loops over all workgroup tiles - for (int kwg = 0; kwg < kSizeK; kwg += KWG) { + for (int kwg = 0; kwg < kSizeK; kwg += KWG * KREG) { // Loads data: off-chip --> local (matrix A) #if SA == 1 @@ -69,9 +82,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #endif // Loops over all workitem tiles, unrolled by a factor KWI - for (int pwi = 0; pwi < KWG; pwi += KWI) { + for (int pwi = 0; pwi < KWG * KREG; pwi += KWI * KREG) { #pragma unroll - for (int _pit = 0; _pit < KWI; _pit += 1) { + for (int _pit = 0; _pit < KWI * KREG; _pit += KREG) { #if SA == 0 || SB == 0 int idk = kwg + pwi + _pit; #endif @@ -79,73 +92,143 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, int kg = pwi + _pit; #endif + // Loads matrix A (kernel 0) or matrix B (kernel 1) #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { // Loads data: local --> private (matrix A) - #if SA == 1 + #if GEMMK == 0 && SA == 1 apm[_mi] = LocalToPrivateA(alm, _mi, kg); // Loads data: off-chip --> private (matrix A) - #else + #elif GEMMK == 0 && SA == 0 apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg); + // Loads data: 2D global --> 2D private (matrix B) + #elif GEMMK == 1 + #pragma unroll + for (int _ki = 0; _ki < KREG; _ki += 1) { + bpm[_ki * (MWI/VWM) + _mi] = GlobalToPrivateB2D(b_ptr, tid_x, _mi, kSizeN, idk, _ki); + } #endif } - // Loads data: local --> private (matrix B) - #pragma unroll - for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - #if SB == 1 - bpm[_ni] = LocalToPrivateB(blm, _ni, kg); - // Loads data: off-chip --> private (matrix B) - #else - bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); - #endif - } + // Loads matrix B (kernel 0) or matrix A (kernel 1) + #if GEMMK == 0 + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + // Loads data: local --> private (matrix B) + #if SB == 1 + bpm[_ni] = LocalToPrivateB(blm, _ni, kg); + // Loads data: off-chip --> private (matrix B) + #else + bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); + #endif + } + #elif GEMMK == 1 + // Loads data: 2D global --> 2D private (matrix A) + #pragma unroll + for (int _ni = 0; _ni < NWI; _ni += 1) { + #pragma unroll + for (int _ki = 0; _ki < KREG/VWN; _ki += 1) { + apm[_ni * (KREG/VWN) + _ki] = GlobalToPrivateA2D(a_ptr, tid_y, _ni, kSizeK, idk, _ki); + } + } + #endif // Performs the accumulation (Cpm += Apm * Bpm) - #pragma unroll - for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #if GEMMK == 0 #pragma unroll - for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - const realM aval = apm[_mi]; - #if VWN == 1 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); - #elif VWN == 2 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); - #elif VWN == 4 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); - cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); - cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); - #elif VWN == 8 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); - cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); - cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); - cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); - cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); - cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); - cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); - #elif VWN == 16 - cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); - cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); - cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); - cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); - cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); - cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); - cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); - cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); - cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); - cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); - cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); - cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); - cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); - cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); - cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); - cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); - #endif + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + const realM aval = apm[_mi]; + #if VWN == 1 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); + #elif VWN == 2 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + #elif VWN == 4 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); + #elif VWN == 8 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + #elif VWN == 16 + cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); + cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); + cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); + cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); + cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); + cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); + cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); + cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); + #endif + } } - } + #elif GEMMK == 1 + #pragma unroll + for (int _ni = 0; _ni < NWI; _ni += 1) { + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + #pragma unroll + for (int _ki = 0; _ki < KREG/VWN; _ki += 1) { + const int index = _ni * (MWI/VWM) + _mi; + const realN aval = apm[_ni * (KREG/VWN) + _ki]; + #if VWN == 1 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval); + #elif VWN == 2 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); + #elif VWN == 4 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w); + #elif VWN == 8 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7); + #elif VWN == 16 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF); + #endif + } + } + } + #endif } } @@ -158,11 +241,16 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #endif // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta + #if GEMMK == 0 + const int cld = kSizeM; + #elif GEMMK == 1 + const int cld = kSizeN; + #endif #pragma unroll for (int _ni = 0; _ni < NWI; _ni += 1) { #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, kSizeM, alpha, beta); + StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, cld, alpha, beta); } } } |