From eae25f57270e99930cbde4476fe0f54e81cf1e4d Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Tue, 3 Apr 2018 21:18:40 +0200 Subject: Added first version of 2D register tiling kernel with A and C transposed as well --- src/kernels/level3/xgemm_part1.opencl | 77 +++++++++--- src/kernels/level3/xgemm_part3.opencl | 216 ++++++++++++++++++++++++---------- 2 files changed, 215 insertions(+), 78 deletions(-) (limited to 'src/kernels/level3') diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl index 4e1c3e61..265bb019 100644 --- a/src/kernels/level3/xgemm_part1.opencl +++ b/src/kernels/level3/xgemm_part1.opencl @@ -7,15 +7,19 @@ // Author(s): // Cedric Nugteren // -// This file contains an optimized matrix-multiplication kernel inspired by the paper by Matsumoto -// et al. and the tutorial on http://www.cedricnugteren.nl/tutorial.php. It is fully configurable -// (and tunable!) using more or less the same parameters/naming conventions as in the paper. It -// supports different data-types (SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM) through a pre-processor define. +// This file contains two optimized matrix-multiplication kernels: +// - Kernel 0: inspired by the paper by Matsumoto et al. and the tutorial on +// http://www.cedricnugteren.nl/tutorial.php +// - Kernel 1: inspired by a Qualcomm optimized GPU kernel with 2D register tiling +// https://developer.qualcomm.com/blog/matrix-multiply-adreno-gpus-part-2-host-code-and-kernel +// Both are fully configurable (and tunable!) using many parameters. Both kernels support +// different data-types (SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM) through a pre-processor define. // -// Matrices are accessed as follows: +// For kernel 0 matrices are accessed as follows: // A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m) // B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n) // C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m) +// For kernel 1, both A and C are transposed w.r.t. the above // // Or as an image (assuming column-major) // K @@ -31,7 +35,7 @@ // o-------o o-----o // // -// This kernel is separated into three files. This is part 1 out of 4. +// This kernel is separated into multiple files. This is part 1 out of 4. // // ================================================================================================= @@ -43,6 +47,9 @@ R"( // Parameters set by the tuner or by the database. Here they are given a basic default value in case // this kernel file is used outside of the CLBlast library. +#ifndef GEMMK + #define GEMMK 0 // Kernel to choose: 0 regular, 1 with 2D register tiling +#endif #ifndef MWG #define MWG 8 // Tile-size in dimension M (e.g. 64, 128) #endif @@ -59,10 +66,10 @@ R"( #define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 
8, 16, 32) #endif #ifndef MDIMA - #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA + #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA (kernel 0 only) #endif #ifndef NDIMB - #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB + #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB (kernel 0 only) #endif #ifndef KWI #define KWI 1 // Unroll factor of the KWG loop (smaller or equal than KWG) @@ -74,16 +81,19 @@ R"( #define VWN 1 // Vector width of matrix B #endif #ifndef STRM - #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0) + #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0) (kernel 0 only) #endif #ifndef STRN - #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0) + #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0) (kernel 0 only) #endif #ifndef SA - #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0) + #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0) (kernel 0 only) #endif #ifndef SB - #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) + #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) (kernel 0 only) +#endif +#ifndef KREG + #define KREG 1 // Amount of register tiling in second dimension, multiple of VWN (kernel 1 only) #endif // Helper parameters based on the above tuning parameters @@ -244,7 +254,7 @@ INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm, LOCAL_PTR re // Caches global off-chip memory directly into per-thread private memory (registers). This function // is specific for caching the A input matrix. -#if SA == 0 +#if SA == 0 && GEMMK == 0 INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int _mi, const int kSizeM, const int idk, const int kwg) { // Computes the indices based on strided/non-strided access @@ -263,7 +273,7 @@ INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int #endif // Same as above, but now for the B input matrix -#if SB == 0 +#if SB == 0 && GEMMK == 0 INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int _ni, const int kSizeN, const int idk) { // Computes the indices based on strided/non-strided access @@ -281,6 +291,45 @@ INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int } #endif +// ================================================================================================= +#if GEMMK == 1 + +// Caches global off-chip memory directly into per-thread private memory (registers). This function +// is specific for caching the A input matrix for kernel 1. 
+INLINE_FUNC realN GlobalToPrivateA2D(const __global real* restrict a_ptr, const int tid_y, const int _ni, + const int kSizeK, const int idk, const int _ki) { + const int a_index = (tid_y * NWI + _ni) * kSizeK + idk + _ki * VWN; + #if VWN == 1 + return a_ptr[a_index]; + #elif VWN == 2 + return vload2(0, a_ptr + a_index); + #elif VWN == 4 + return vload4(0, a_ptr + a_index); + #elif VWN == 8 + return vload8(0, a_ptr + a_index); + #elif VWN == 16 + return vload16(0, a_ptr + a_index); + #endif +} + +// Same as above, but now for the B input matrix +INLINE_FUNC realM GlobalToPrivateB2D(const __global real* restrict b_ptr, const int tid_x, const int _mi, + const int kSizeN, const int idk, const int _ki) { + const int b_index = (idk + _ki) * kSizeN + tid_x * MWI + _mi * VWM; + #if VWM == 1 + return b_ptr[b_index]; + #elif VWM == 2 + return vload2(0, b_ptr + b_index); + #elif VWM == 4 + return vload4(0, b_ptr + b_index); + #elif VWM == 8 + return vload8(0, b_ptr + b_index); + #elif VWM == 16 + return vload16(0, b_ptr + b_index); + #endif +} + +#endif // ================================================================================================= // Caches on-chip local memory into per-thread private memory (registers). This function is specific diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl index 08778f0d..d7ddeb15 100644 --- a/src/kernels/level3/xgemm_part3.opencl +++ b/src/kernels/level3/xgemm_part3.opencl @@ -31,12 +31,26 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, ) { // Allocates workitem-private memory (registers) + #if GEMMK == 0 + #pragma promote_to_registers + realM apm[MWI/VWM]; // MWI * 1 + #pragma promote_to_registers + realN bpm[NWI/VWN]; // 1 * NWI + #elif GEMMK == 1 + #pragma promote_to_registers + realN apm[NWI*(KREG/VWN)]; // NWI * KREG + #pragma promote_to_registers + realM bpm[KREG*(MWI/VWM)]; // KREG * MWI + #endif #pragma promote_to_registers - realM apm[MWI/VWM]; - #pragma promote_to_registers - realN bpm[NWI/VWN]; - #pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; + realM cpm[NWI*(MWI/VWM)]; // NWI * MWI + + #if GEMMK == 1 + const __global real* restrict a_ptr = (const __global real* restrict) &agm[0]; + const __global real* restrict b_ptr = (const __global real* restrict) &bgm[0]; + const int tid_x = get_global_id(0); + const int tid_y = get_global_id(1); + #endif // Combined thread identifier (volatile to disable caching) #if SA == 1 || SB == 1 @@ -52,9 +66,8 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, } } - // Loops over all workgroup tiles - for (int kwg = 0; kwg < kSizeK; kwg += KWG) { + for (int kwg = 0; kwg < kSizeK; kwg += KWG * KREG) { // Loads data: off-chip --> local (matrix A) #if SA == 1 @@ -69,9 +82,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #endif // Loops over all workitem tiles, unrolled by a factor KWI - for (int pwi = 0; pwi < KWG; pwi += KWI) { + for (int pwi = 0; pwi < KWG * KREG; pwi += KWI * KREG) { #pragma unroll - for (int _pit = 0; _pit < KWI; _pit += 1) { + for (int _pit = 0; _pit < KWI * KREG; _pit += KREG) { #if SA == 0 || SB == 0 int idk = kwg + pwi + _pit; #endif @@ -79,73 +92,143 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, int kg = pwi + _pit; #endif + // Loads matrix A (kernel 0) or matrix B (kernel 1) #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { // Loads data: local --> private (matrix A) - #if SA == 1 + #if 
GEMMK == 0 && SA == 1 apm[_mi] = LocalToPrivateA(alm, _mi, kg); // Loads data: off-chip --> private (matrix A) - #else + #elif GEMMK == 0 && SA == 0 apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg); + // Loads data: 2D global --> 2D private (matrix B) + #elif GEMMK == 1 + #pragma unroll + for (int _ki = 0; _ki < KREG; _ki += 1) { + bpm[_ki * (MWI/VWM) + _mi] = GlobalToPrivateB2D(b_ptr, tid_x, _mi, kSizeN, idk, _ki); + } #endif } - // Loads data: local --> private (matrix B) - #pragma unroll - for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - #if SB == 1 - bpm[_ni] = LocalToPrivateB(blm, _ni, kg); - // Loads data: off-chip --> private (matrix B) - #else - bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); - #endif - } + // Loads matrix B (kernel 0) or matrix A (kernel 1) + #if GEMMK == 0 + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + // Loads data: local --> private (matrix B) + #if SB == 1 + bpm[_ni] = LocalToPrivateB(blm, _ni, kg); + // Loads data: off-chip --> private (matrix B) + #else + bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); + #endif + } + #elif GEMMK == 1 + // Loads data: 2D global --> 2D private (matrix A) + #pragma unroll + for (int _ni = 0; _ni < NWI; _ni += 1) { + #pragma unroll + for (int _ki = 0; _ki < KREG/VWN; _ki += 1) { + apm[_ni * (KREG/VWN) + _ki] = GlobalToPrivateA2D(a_ptr, tid_y, _ni, kSizeK, idk, _ki); + } + } + #endif // Performs the accumulation (Cpm += Apm * Bpm) - #pragma unroll - for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #if GEMMK == 0 #pragma unroll - for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - const realM aval = apm[_mi]; - #if VWN == 1 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); - #elif VWN == 2 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); - #elif VWN == 4 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); - cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); - cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); - #elif VWN == 8 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); - cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); - cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); - cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); - cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); - cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); - cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); - #elif VWN == 16 - cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); - cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = 
MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); - cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); - cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); - cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); - cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); - cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); - cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); - cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); - cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); - cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); - cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); - cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); - cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); - cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); - cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); - #endif + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + const realM aval = apm[_mi]; + #if VWN == 1 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); + #elif VWN == 2 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + #elif VWN == 4 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); + #elif VWN == 8 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 
7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + #elif VWN == 16 + cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); + cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); + cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); + cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); + cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); + cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); + cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); + cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); + #endif + } } - } + #elif GEMMK == 1 + #pragma unroll + for (int _ni = 0; _ni < NWI; _ni += 1) { + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + #pragma unroll + for (int _ki = 0; _ki < KREG/VWN; _ki += 1) { + const int index = _ni * (MWI/VWM) + _mi; + const realN aval = apm[_ni * (KREG/VWN) + _ki]; + #if VWN == 1 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval); + #elif VWN == 2 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); + #elif VWN == 4 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w); + #elif VWN == 8 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5); + cpm[index] = 
MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7); + #elif VWN == 16 + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE); + cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF); + #endif + } + } + } + #endif } } @@ -158,11 +241,16 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #endif // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta + #if GEMMK == 0 + const int cld = kSizeM; + #elif GEMMK == 1 + const int cld = kSizeN; + #endif #pragma unroll for (int _ni = 0; _ni < NWI; _ni += 1) { #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, kSizeM, alpha, beta); + StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, cld, alpha, beta); } } } -- cgit v1.2.3
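
For reference, the header comment added in xgemm_part1.opencl fixes the storage conventions of the two kernels: kernel 0 reads A as [k*M + m] and B as [k*N + n] and writes C as [n*M + m], while the new GEMMK == 1 path expects A and C transposed with respect to that, presumably A as [m*K + k] and C as [m*N + n] with B unchanged (consistent with cld = kSizeN in the store loop). The plain-C sketch below is only a scalar illustration of that kernel-1 layout; the function name and loop order are not part of the commit.

/*
 * Scalar reference for the data layout the GEMMK == 1 path appears to assume:
 *   A: [m*K + k]  (transposed w.r.t. kernel 0's [k*M + m])
 *   B: [k*N + n]  (same as kernel 0)
 *   C: [m*N + n]  (transposed w.r.t. kernel 0's [n*M + m])
 */
static void reference_gemm_kernel1_layout(const int kSizeM, const int kSizeN, const int kSizeK,
                                          const float alpha, const float beta,
                                          const float* agm, const float* bgm, float* cgm) {
  for (int m = 0; m < kSizeM; ++m) {
    for (int n = 0; n < kSizeN; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < kSizeK; ++k) {
        acc += agm[m * kSizeK + k] * bgm[k * kSizeN + n];  /* A read along its K-stride */
      }
      cgm[m * kSizeN + n] = alpha * acc + beta * cgm[m * kSizeN + n];  /* C stored with leading dimension N */
    }
  }
}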
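
The 2D register tiling itself is visible in the index arithmetic of the new GlobalToPrivateA2D and GlobalToPrivateB2D helpers: per k-step, each work-item fetches an NWI x KREG strip of A (rows selected via tid_y, leading dimension kSizeK) and a KREG x MWI strip of B (rows selected via idk, columns via tid_x, leading dimension kSizeN) into registers before the multiply-accumulate. Below is a scalar paraphrase of those two loads with the VWN/VWM vector widths folded into the loops; the NWI/MWI/KREG values are example tuner settings, not values taken from the commit.

/* Scalar paraphrase of the GEMMK == 1 per-thread register-tile loads. */
#define NWI  4   /* example: work per work-item in N  */
#define MWI  4   /* example: work per work-item in M  */
#define KREG 2   /* example: register tile depth in K */

static void load_register_tiles(const float* a_ptr, const float* b_ptr,
                                const int kSizeK, const int kSizeN,
                                const int tid_x, const int tid_y, const int idk,
                                float apm[NWI][KREG],    /* NWI x KREG strip of A */
                                float bpm[KREG][MWI]) {  /* KREG x MWI strip of B */
  for (int _ni = 0; _ni < NWI; ++_ni) {
    for (int _ki = 0; _ki < KREG; ++_ki) {
      /* cf. GlobalToPrivateA2D: (tid_y * NWI + _ni) * kSizeK + idk + _ki * VWN */
      apm[_ni][_ki] = a_ptr[(tid_y * NWI + _ni) * kSizeK + idk + _ki];
    }
  }
  for (int _ki = 0; _ki < KREG; ++_ki) {
    for (int _mi = 0; _mi < MWI; ++_mi) {
      /* cf. GlobalToPrivateB2D: (idk + _ki) * kSizeN + tid_x * MWI + _mi * VWM */
      bpm[_ki][_mi] = b_ptr[(idk + _ki) * kSizeN + tid_x * MWI + _mi];
    }
  }
}

Keeping both strips entirely in registers is what lets the GEMMK == 1 path bypass the SA/SB local-memory staging, which the updated parameter comments now mark as kernel 0 only.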