summaryrefslogtreecommitdiff
path: root/src/kernels/level3
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-04-03 21:18:40 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2018-04-03 21:18:40 +0200
commiteae25f57270e99930cbde4476fe0f54e81cf1e4d (patch)
treeccc803b2cab58e38e15b5a97987aeca4cc635440 /src/kernels/level3
parent63996eb68b6542f5267304e861da969abeff777a (diff)
Added first version of 2D register tiling kernel with A and C transposed as well
Diffstat (limited to 'src/kernels/level3')
-rw-r--r--src/kernels/level3/xgemm_part1.opencl77
-rw-r--r--src/kernels/level3/xgemm_part3.opencl216
2 files changed, 215 insertions, 78 deletions
diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl
index 4e1c3e61..265bb019 100644
--- a/src/kernels/level3/xgemm_part1.opencl
+++ b/src/kernels/level3/xgemm_part1.opencl
@@ -7,15 +7,19 @@
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
-// This file contains an optimized matrix-multiplication kernel inspired by the paper by Matsumoto
-// et al. and the tutorial on http://www.cedricnugteren.nl/tutorial.php. It is fully configurable
-// (and tunable!) using more or less the same parameters/naming conventions as in the paper. It
-// supports different data-types (SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM) through a pre-processor define.
+// This file contains two optimized matrix-multiplication kernels:
+// - Kernel 0: inspired by the paper by Matsumoto et al. and the tutorial on
+// http://www.cedricnugteren.nl/tutorial.php
+// - Kernel 1: inspired by a Qualcomm optimized GPU kernel with 2D register tiling
+// https://developer.qualcomm.com/blog/matrix-multiply-adreno-gpus-part-2-host-code-and-kernel
+// Both are fully configurable (and tunable!) using many parameters. Both kernels support
+// different data-types (SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM) through a pre-processor define.
//
-// Matrices are accessed as follows:
+// For kernel 0 matrices are accessed as follows:
// A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m)
// B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n)
// C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m)
+// For kernel 1, both A and C are transposed w.r.t. the above
//
// Or as an image (assuming column-major)
// K
@@ -31,7 +35,7 @@
// o-------o o-----o
//
//
-// This kernel is separated into three files. This is part 1 out of 4.
+// This kernel is separated into multiple files. This is part 1 out of 4.
//
// =================================================================================================
@@ -43,6 +47,9 @@ R"(
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
// this kernel file is used outside of the CLBlast library.
+#ifndef GEMMK
+ #define GEMMK 0 // Kernel to choose: 0 regular, 1 with 2D register tiling
+#endif
#ifndef MWG
#define MWG 8 // Tile-size in dimension M (e.g. 64, 128)
#endif
@@ -59,10 +66,10 @@ R"(
#define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 8, 16, 32)
#endif
#ifndef MDIMA
- #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
+ #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA (kernel 0 only)
#endif
#ifndef NDIMB
- #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
+ #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB (kernel 0 only)
#endif
#ifndef KWI
#define KWI 1 // Unroll factor of the KWG loop (smaller or equal than KWG)
@@ -74,16 +81,19 @@ R"(
#define VWN 1 // Vector width of matrix B
#endif
#ifndef STRM
- #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0)
+ #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0) (kernel 0 only)
#endif
#ifndef STRN
- #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0)
+ #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0) (kernel 0 only)
#endif
#ifndef SA
- #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0)
+ #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0) (kernel 0 only)
#endif
#ifndef SB
- #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0)
+ #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) (kernel 0 only)
+#endif
+#ifndef KREG
+ #define KREG 1 // Amount of register tiling in second dimension, multiple of VWN (kernel 1 only)
#endif
// Helper parameters based on the above tuning parameters
@@ -244,7 +254,7 @@ INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm, LOCAL_PTR re
// Caches global off-chip memory directly into per-thread private memory (registers). This function
// is specific for caching the A input matrix.
-#if SA == 0
+#if SA == 0 && GEMMK == 0
INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int _mi,
const int kSizeM, const int idk, const int kwg) {
// Computes the indices based on strided/non-strided access
@@ -263,7 +273,7 @@ INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int
#endif
// Same as above, but now for the B input matrix
-#if SB == 0
+#if SB == 0 && GEMMK == 0
INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int _ni,
const int kSizeN, const int idk) {
// Computes the indices based on strided/non-strided access
@@ -282,6 +292,45 @@ INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int
#endif
// =================================================================================================
+#if GEMMK == 1
+
+// Caches global off-chip memory directly into per-thread private memory (registers). This function
+// is specific for caching the A input matrix for kernel 1.
+INLINE_FUNC realN GlobalToPrivateA2D(const __global real* restrict a_ptr, const int tid_y, const int _ni,
+ const int kSizeK, const int idk, const int _ki) {
+ const int a_index = (tid_y * NWI + _ni) * kSizeK + idk + _ki * VWN;
+ #if VWN == 1
+ return a_ptr[a_index];
+ #elif VWN == 2
+ return vload2(0, a_ptr + a_index);
+ #elif VWN == 4
+ return vload4(0, a_ptr + a_index);
+ #elif VWN == 8
+ return vload8(0, a_ptr + a_index);
+ #elif VWN == 16
+ return vload16(0, a_ptr + a_index);
+ #endif
+}
+
+// Same as above, but now for the B input matrix
+INLINE_FUNC realM GlobalToPrivateB2D(const __global real* restrict b_ptr, const int tid_x, const int _mi,
+ const int kSizeN, const int idk, const int _ki) {
+ const int b_index = (idk + _ki) * kSizeN + tid_x * MWI + _mi * VWM;
+ #if VWM == 1
+ return b_ptr[b_index];
+ #elif VWM == 2
+ return vload2(0, b_ptr + b_index);
+ #elif VWM == 4
+ return vload4(0, b_ptr + b_index);
+ #elif VWM == 8
+ return vload8(0, b_ptr + b_index);
+ #elif VWM == 16
+ return vload16(0, b_ptr + b_index);
+ #endif
+}
+
+#endif
+// =================================================================================================
// Caches on-chip local memory into per-thread private memory (registers). This function is specific
// for caching the A input matrix.
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl
index 08778f0d..d7ddeb15 100644
--- a/src/kernels/level3/xgemm_part3.opencl
+++ b/src/kernels/level3/xgemm_part3.opencl
@@ -31,12 +31,26 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
) {
// Allocates workitem-private memory (registers)
+ #if GEMMK == 0
+ #pragma promote_to_registers
+ realM apm[MWI/VWM]; // MWI * 1
+ #pragma promote_to_registers
+ realN bpm[NWI/VWN]; // 1 * NWI
+ #elif GEMMK == 1
+ #pragma promote_to_registers
+ realN apm[NWI*(KREG/VWN)]; // NWI * KREG
+ #pragma promote_to_registers
+ realM bpm[KREG*(MWI/VWM)]; // KREG * MWI
+ #endif
#pragma promote_to_registers
- realM apm[MWI/VWM];
- #pragma promote_to_registers
- realN bpm[NWI/VWN];
- #pragma promote_to_registers
- realM cpm[NWI*(MWI/VWM)];
+ realM cpm[NWI*(MWI/VWM)]; // NWI * MWI
+
+ #if GEMMK == 1
+ const __global real* restrict a_ptr = (const __global real* restrict) &agm[0];
+ const __global real* restrict b_ptr = (const __global real* restrict) &bgm[0];
+ const int tid_x = get_global_id(0);
+ const int tid_y = get_global_id(1);
+ #endif
// Combined thread identifier (volatile to disable caching)
#if SA == 1 || SB == 1
@@ -52,9 +66,8 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
}
}
-
// Loops over all workgroup tiles
- for (int kwg = 0; kwg < kSizeK; kwg += KWG) {
+ for (int kwg = 0; kwg < kSizeK; kwg += KWG * KREG) {
// Loads data: off-chip --> local (matrix A)
#if SA == 1
@@ -69,9 +82,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// Loops over all workitem tiles, unrolled by a factor KWI
- for (int pwi = 0; pwi < KWG; pwi += KWI) {
+ for (int pwi = 0; pwi < KWG * KREG; pwi += KWI * KREG) {
#pragma unroll
- for (int _pit = 0; _pit < KWI; _pit += 1) {
+ for (int _pit = 0; _pit < KWI * KREG; _pit += KREG) {
#if SA == 0 || SB == 0
int idk = kwg + pwi + _pit;
#endif
@@ -79,73 +92,143 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
int kg = pwi + _pit;
#endif
+ // Loads matrix A (kernel 0) or matrix B (kernel 1)
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
// Loads data: local --> private (matrix A)
- #if SA == 1
+ #if GEMMK == 0 && SA == 1
apm[_mi] = LocalToPrivateA(alm, _mi, kg);
// Loads data: off-chip --> private (matrix A)
- #else
+ #elif GEMMK == 0 && SA == 0
apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg);
+ // Loads data: 2D global --> 2D private (matrix B)
+ #elif GEMMK == 1
+ #pragma unroll
+ for (int _ki = 0; _ki < KREG; _ki += 1) {
+ bpm[_ki * (MWI/VWM) + _mi] = GlobalToPrivateB2D(b_ptr, tid_x, _mi, kSizeN, idk, _ki);
+ }
#endif
}
- // Loads data: local --> private (matrix B)
- #pragma unroll
- for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
- #if SB == 1
- bpm[_ni] = LocalToPrivateB(blm, _ni, kg);
- // Loads data: off-chip --> private (matrix B)
- #else
- bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk);
- #endif
- }
+ // Loads matrix B (kernel 0) or matrix A (kernel 1)
+ #if GEMMK == 0
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
+ // Loads data: local --> private (matrix B)
+ #if SB == 1
+ bpm[_ni] = LocalToPrivateB(blm, _ni, kg);
+ // Loads data: off-chip --> private (matrix B)
+ #else
+ bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk);
+ #endif
+ }
+ #elif GEMMK == 1
+ // Loads data: 2D global --> 2D private (matrix A)
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI; _ni += 1) {
+ #pragma unroll
+ for (int _ki = 0; _ki < KREG/VWN; _ki += 1) {
+ apm[_ni * (KREG/VWN) + _ki] = GlobalToPrivateA2D(a_ptr, tid_y, _ni, kSizeK, idk, _ki);
+ }
+ }
+ #endif
// Performs the accumulation (Cpm += Apm * Bpm)
- #pragma unroll
- for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
+ #if GEMMK == 0
#pragma unroll
- for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
- const realM aval = apm[_mi];
- #if VWN == 1
- cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
- #elif VWN == 2
- cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
- cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
- #elif VWN == 4
- cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
- cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
- cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
- cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
- #elif VWN == 8
- cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
- cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
- cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
- cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
- cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
- cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
- cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
- cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
- #elif VWN == 16
- cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
- cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
- cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
- cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
- cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
- cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
- cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
- cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
- cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
- cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
- cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
- cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
- cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
- cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
- cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
- cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
- #endif
+ for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ const realM aval = apm[_mi];
+ #if VWN == 1
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
+ #elif VWN == 2
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
+ #elif VWN == 4
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
+ #elif VWN == 8
+ cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
+ #elif VWN == 16
+ cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
+ cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
+ cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
+ cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
+ cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
+ cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
+ cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
+ cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
+ cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
+ cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
+ cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
+ cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
+ cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
+ cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
+ cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
+ cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
+ #endif
+ }
}
- }
+ #elif GEMMK == 1
+ #pragma unroll
+ for (int _ni = 0; _ni < NWI; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
+ #pragma unroll
+ for (int _ki = 0; _ki < KREG/VWN; _ki += 1) {
+ const int index = _ni * (MWI/VWM) + _mi;
+ const realN aval = apm[_ni * (KREG/VWN) + _ki];
+ #if VWN == 1
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval);
+ #elif VWN == 2
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y);
+ #elif VWN == 4
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w);
+ #elif VWN == 8
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7);
+ #elif VWN == 16
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE);
+ cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF);
+ #endif
+ }
+ }
+ }
+ #endif
}
}
@@ -158,11 +241,16 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
+ #if GEMMK == 0
+ const int cld = kSizeM;
+ #elif GEMMK == 1
+ const int cld = kSizeN;
+ #endif
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
- StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, kSizeM, alpha, beta);
+ StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, cld, alpha, beta);
}
}
}