summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-02-08 20:06:02 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2016-02-08 20:06:02 +0100
commitbf84463ab20f2f39071719fad9bd28a6bb13fc24 (patch)
treedf4a6fff31d178186bd36538da3705ccf3353eb3 /src
parent38c56bbde2ed108d47bd058ba239725b3396475d (diff)
Separated the GEMM kernel in two parts to reduce string length for MSVC
Diffstat (limited to 'src')
-rw-r--r--src/kernels/level3/xgemm_part1.opencl329
-rw-r--r--src/kernels/level3/xgemm_part2.opencl (renamed from src/kernels/level3/xgemm.opencl)306
-rw-r--r--src/routines/level3/xgemm.cc3
-rw-r--r--src/routines/level3/xher2k.cc3
-rw-r--r--src/routines/level3/xherk.cc3
-rw-r--r--src/routines/level3/xsyr2k.cc3
-rw-r--r--src/routines/level3/xsyrk.cc3
-rw-r--r--src/tuning/xgemm.cc3
8 files changed, 342 insertions, 311 deletions
diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl
new file mode 100644
index 00000000..4cb0585b
--- /dev/null
+++ b/src/kernels/level3/xgemm_part1.opencl
@@ -0,0 +1,329 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file contains an optimized matrix-multiplication kernel according to the paper by Matsumoto
+// et al. and the tutorial on http://www.cedricnugteren.nl/tutorial.php. It is fully configurable
+// (and tunable!) using more or less the same parameters/naming conventions as in the paper. It
+// supports single and double precision (SGEMM/DGEMM) through a pre-processor define.
+//
+// Matrices are accessed as follows:
+// A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m)
+// B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n)
+// C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m)
+//
+// Or as an image (assuming column-major)
+// K
+// o-------o
+// | |
+// N | [B^T] |
+// | |
+// o-------o
+// K N
+// o-------o o-----o
+// M | [A] | M | [C] |
+// | | | |
+// o-------o o-----o
+//
+//
+// This kernel is seperated into two files. This is part 1 out of 2,
+//
+// =================================================================================================
+
+// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
+// literal). Comment-out this line for syntax-highlighting when developing.
+R"(
+
+// =================================================================================================
+
+// Parameters set by the tuner or by the database. Here they are given a basic default value in case
+// this kernel file is used outside of the CLBlast library.
+#ifndef MWG
+ #define MWG 8 // Tile-size in dimension M (e.g. 64, 128)
+#endif
+#ifndef NWG
+ #define NWG 8 // Tile-size in dimension N (e.g. 64, 128)
+#endif
+#ifndef KWG
+ #define KWG 8 // Tile-size in dimension K (e.g. 8, 16)
+#endif
+#ifndef MDIMC
+ #define MDIMC 8 // Threads per workgroup in M-dimension (e.g. 8, 16, 32)
+#endif
+#ifndef NDIMC
+ #define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 8, 16, 32)
+#endif
+#ifndef MDIMA
+ #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
+#endif
+#ifndef NDIMB
+ #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
+#endif
+#ifndef KWI
+ #define KWI 1 // Unroll factor of the KWG loop (smaller or equal than KWG)
+#endif
+#ifndef VWM
+ #define VWM 1 // Vector width of matrices A and C
+#endif
+#ifndef VWN
+ #define VWN 1 // Vector width of matrix B
+#endif
+#ifndef STRM
+ #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0)
+#endif
+#ifndef STRN
+ #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0)
+#endif
+#ifndef SA
+ #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0)
+#endif
+#ifndef SB
+ #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0)
+#endif
+
+// Helper parameters based on the above tuning parameters
+#define MWI (MWG/MDIMC) // Work per work-item (M-dimension)
+#define NWI (NWG/NDIMC) // Work per work-item (N-dimension)
+#define KDIMA ((MDIMC*NDIMC)/(MDIMA)) // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
+#define KDIMB ((MDIMC*NDIMC)/(NDIMB)) // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
+#define MWA (MWG/MDIMA) // Amount of loads-per-thread for matrix A (M-dimension)
+#define KWA (KWG/KDIMA) // Amount of loads-per-thread for matrix A (K-dimension)
+#define KWB (KWG/KDIMB) // Amount of loads-per-thread for matrix B (K-dimension)
+#define NWB (NWG/NDIMB) // Amount of loads-per-thread for matrix B (N-dimension)
+
+// Settings
+#define USE_VECTOR_MAD 0 // Unroll (0) or don't (1) unroll the vector MAD manually
+
+// =================================================================================================
+
+// Data-widths in dimension M
+#if VWM == 1
+ typedef real realM;
+#elif VWM == 2
+ typedef real2 realM;
+#elif VWM == 4
+ typedef real4 realM;
+#elif VWM == 8
+ typedef real8 realM;
+#elif VWM == 16
+ typedef real16 realM;
+#endif
+
+// Data-widths in dimension N
+#if VWN == 1
+ typedef real realN;
+#elif VWN == 2
+ typedef real2 realN;
+#elif VWN == 4
+ typedef real4 realN;
+#elif VWN == 8
+ typedef real8 realN;
+#elif VWN == 16
+ typedef real16 realN;
+#endif
+
+// =================================================================================================
+
+// Initializes the accumulation registers to zero
+inline void InitAccRegisters(realM cpm[NWI][MWI/VWM]) {
+ #pragma unroll
+ for (int mi=0; mi<MWI/VWM; ++mi) {
+ #pragma unroll
+ for (int ni=0; ni<NWI; ++ni) {
+ #if VWM == 1
+ SetToZero(cpm[ni][mi]);
+ #elif VWM == 2
+ SetToZero(cpm[ni][mi].x);
+ SetToZero(cpm[ni][mi].y);
+ #elif VWM == 4
+ SetToZero(cpm[ni][mi].x);
+ SetToZero(cpm[ni][mi].y);
+ SetToZero(cpm[ni][mi].z);
+ SetToZero(cpm[ni][mi].w);
+ #elif VWM == 8
+ SetToZero(cpm[ni][mi].s0);
+ SetToZero(cpm[ni][mi].s1);
+ SetToZero(cpm[ni][mi].s2);
+ SetToZero(cpm[ni][mi].s3);
+ SetToZero(cpm[ni][mi].s4);
+ SetToZero(cpm[ni][mi].s5);
+ SetToZero(cpm[ni][mi].s6);
+ SetToZero(cpm[ni][mi].s7);
+ #elif VWM == 16
+ SetToZero(cpm[ni][mi].s0);
+ SetToZero(cpm[ni][mi].s1);
+ SetToZero(cpm[ni][mi].s2);
+ SetToZero(cpm[ni][mi].s3);
+ SetToZero(cpm[ni][mi].s4);
+ SetToZero(cpm[ni][mi].s5);
+ SetToZero(cpm[ni][mi].s6);
+ SetToZero(cpm[ni][mi].s7);
+ SetToZero(cpm[ni][mi].s8);
+ SetToZero(cpm[ni][mi].s9);
+ SetToZero(cpm[ni][mi].sA);
+ SetToZero(cpm[ni][mi].sB);
+ SetToZero(cpm[ni][mi].sC);
+ SetToZero(cpm[ni][mi].sD);
+ SetToZero(cpm[ni][mi].sE);
+ SetToZero(cpm[ni][mi].sF);
+ #endif
+ }
+ }
+}
+
+// =================================================================================================
+
+// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
+// caching the A input matrix.
+#if SA == 1
+inline void GlobalToLocalA(const __global realM* restrict agm, __local realM* alm,
+ const int kSizeM, const int tid, const int kwg) {
+ const int la0 = tid % MDIMA;
+ const int la1 = tid / MDIMA;
+ #pragma unroll
+ for (int mia=0; mia<MWA/VWM; ++mia) {
+ #pragma unroll
+ for (int kia=0; kia<KWA; ++kia) {
+
+ // Computes the indices based on strided/non-strided access
+ #if STRM == 0
+ int mg = mia + la0*(MWA/VWM);
+ #elif STRM == 1
+ int mg = la0 + mia*MDIMA;
+ #endif
+
+ // Computes the indices for the global memory
+ int kg = kia + la1*KWA;
+ int idm = mg + get_group_id(0)*(MWG/VWM);
+ int idk = kg + kwg;
+
+ // Loads the data from global memory (not transposed) into the local memory
+ alm[kg*(MWG/VWM) + mg] = agm[idk*(kSizeM/VWM) + idm];
+ }
+ }
+}
+#endif
+
+// Same as above, but now for the B input matrix
+#if SB == 1
+inline void GlobalToLocalB(const __global realN* restrict bgm, __local realN* blm,
+ const int kSizeN, const int tid, const int kwg) {
+ const int lb0 = tid % NDIMB;
+ const int lb1 = tid / NDIMB;
+ #pragma unroll
+ for (int kib=0; kib<KWB; ++kib) {
+ #pragma unroll
+ for (int nib=0; nib<NWB/VWN; ++nib) {
+
+ // Computes the indices based on strided/non-strided access
+ #if STRN == 0
+ int ng = nib + lb0*(NWB/VWN);
+ #elif STRN == 1
+ int ng = lb0 + nib*NDIMB;
+ #endif
+
+ // Computes the indices for the global memory
+ int kg = kib + lb1*KWB;
+ int idn = ng + get_group_id(1)*(NWG/VWN);
+ int idk = kg + kwg;
+
+ // Loads the data from global memory (transposed) into the local memory
+ blm[kg*(NWG/VWN) + ng] = bgm[idk*(kSizeN/VWN) + idn];
+ }
+ }
+}
+#endif
+
+// =================================================================================================
+
+// Caches global off-chip memory directly into per-thread private memory (registers). This function
+// is specific for caching the A input matrix.
+#if SA == 0
+inline void GlobalToPrivateA(const __global realM* restrict agm, realM apm[MWI/VWM],
+ const int kSizeM, const int idk, const int kwg) {
+ #pragma unroll
+ for (int mi=0; mi<MWI/VWM; ++mi) {
+
+ // Computes the indices based on strided/non-strided access
+ #if STRM == 0
+ int mg = mi + get_local_id(0)*(MWI/VWM);
+ #elif STRM == 1
+ int mg = get_local_id(0) + mi*MDIMC;
+ #endif
+
+ // Computes the indices for the global memory
+ int idm = mg + get_group_id(0)*(MWG/VWM);
+
+ // Loads the data from global memory (not transposed) and stores into registers
+ apm[mi] = agm[idk*(kSizeM/VWM) + idm];
+ }
+}
+#endif
+
+// Same as above, but now for the B input matrix
+#if SB == 0
+inline void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[NWI/VWN],
+ const int kSizeN, const int idk) {
+ #pragma unroll
+ for (int ni=0; ni<NWI/VWN; ++ni) {
+
+ // Computes the indices based on strided/non-strided access
+ #if STRN == 0
+ int ng = ni + get_local_id(1)*(NWI/VWN);
+ #elif STRN == 1
+ int ng = get_local_id(1) + ni*NDIMC;
+ #endif
+
+ // Computes the indices for the global memory
+ int idn = ng + get_group_id(1)*(NWG/VWN);
+
+ // Loads the data from global memory (transposed) and stores into registers
+ bpm[ni] = bgm[idk*(kSizeN/VWN) + idn];
+ }
+}
+#endif
+
+// =================================================================================================
+
+// Caches on-chip local memory into per-thread private memory (registers). This function is specific
+// for caching the A input matrix.
+#if SA == 1
+inline void LocalToPrivateA(__local realM* alm, realM apm[MWI/VWM], const int kg) {
+ #pragma unroll
+ for (int mi=0; mi<MWI/VWM; ++mi) {
+ #if STRM == 0
+ int mg = mi + get_local_id(0)*(MWI/VWM);
+ #elif STRM == 1
+ int mg = get_local_id(0) + mi*MDIMC;
+ #endif
+ apm[mi] = alm[kg*(MWG/VWM) + mg];
+ }
+}
+#endif
+
+// Same as above, but now for the B input matrix
+#if SB == 1
+inline void LocalToPrivateB(__local realN* blm, realN bpm[NWI/VWN], const int kg) {
+ #pragma unroll
+ for (int ni=0; ni<NWI/VWN; ++ni) {
+ #if STRN == 0
+ int ng = ni + get_local_id(1)*(NWI/VWN);
+ #elif STRN == 1
+ int ng = get_local_id(1) + ni*NDIMC;
+ #endif
+ bpm[ni] = blm[kg*(NWG/VWN) + ng];
+ }
+}
+#endif
+
+// =================================================================================================
+
+// End of the C++11 raw string literal
+)"
+
+// =================================================================================================
diff --git a/src/kernels/level3/xgemm.opencl b/src/kernels/level3/xgemm_part2.opencl
index 8db0f557..c0760db6 100644
--- a/src/kernels/level3/xgemm.opencl
+++ b/src/kernels/level3/xgemm_part2.opencl
@@ -7,29 +7,7 @@
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
-// This file contains an optimized matrix-multiplication kernel according to the paper by Matsumoto
-// et al. and the tutorial on http://www.cedricnugteren.nl/tutorial.php. It is fully configurable
-// (and tunable!) using more or less the same parameters/naming conventions as in the paper. It
-// supports single and double precision (SGEMM/DGEMM) through a pre-processor define.
-//
-// Matrices are accessed as follows:
-// A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m)
-// B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n)
-// C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m)
-//
-// Or as an image (assuming column-major)
-// K
-// o-------o
-// | |
-// N | [B^T] |
-// | |
-// o-------o
-// K N
-// o-------o o-----o
-// M | [A] | M | [C] |
-// | | | |
-// o-------o o-----o
-//
+// This is part 2 of 2 of the GEMM kernel. See part 1 for more information.
//
// =================================================================================================
@@ -39,288 +17,6 @@ R"(
// =================================================================================================
-// Parameters set by the tuner or by the database. Here they are given a basic default value in case
-// this kernel file is used outside of the CLBlast library.
-#ifndef MWG
- #define MWG 8 // Tile-size in dimension M (e.g. 64, 128)
-#endif
-#ifndef NWG
- #define NWG 8 // Tile-size in dimension N (e.g. 64, 128)
-#endif
-#ifndef KWG
- #define KWG 8 // Tile-size in dimension K (e.g. 8, 16)
-#endif
-#ifndef MDIMC
- #define MDIMC 8 // Threads per workgroup in M-dimension (e.g. 8, 16, 32)
-#endif
-#ifndef NDIMC
- #define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 8, 16, 32)
-#endif
-#ifndef MDIMA
- #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
-#endif
-#ifndef NDIMB
- #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
-#endif
-#ifndef KWI
- #define KWI 1 // Unroll factor of the KWG loop (smaller or equal than KWG)
-#endif
-#ifndef VWM
- #define VWM 1 // Vector width of matrices A and C
-#endif
-#ifndef VWN
- #define VWN 1 // Vector width of matrix B
-#endif
-#ifndef STRM
- #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0)
-#endif
-#ifndef STRN
- #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0)
-#endif
-#ifndef SA
- #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0)
-#endif
-#ifndef SB
- #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0)
-#endif
-
-// Helper parameters based on the above tuning parameters
-#define MWI (MWG/MDIMC) // Work per work-item (M-dimension)
-#define NWI (NWG/NDIMC) // Work per work-item (N-dimension)
-#define KDIMA ((MDIMC*NDIMC)/(MDIMA)) // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
-#define KDIMB ((MDIMC*NDIMC)/(NDIMB)) // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
-#define MWA (MWG/MDIMA) // Amount of loads-per-thread for matrix A (M-dimension)
-#define KWA (KWG/KDIMA) // Amount of loads-per-thread for matrix A (K-dimension)
-#define KWB (KWG/KDIMB) // Amount of loads-per-thread for matrix B (K-dimension)
-#define NWB (NWG/NDIMB) // Amount of loads-per-thread for matrix B (N-dimension)
-
-// Settings
-#define USE_VECTOR_MAD 0 // Unroll (0) or don't (1) unroll the vector MAD manually
-
-// =================================================================================================
-
-// Data-widths in dimension M
-#if VWM == 1
- typedef real realM;
-#elif VWM == 2
- typedef real2 realM;
-#elif VWM == 4
- typedef real4 realM;
-#elif VWM == 8
- typedef real8 realM;
-#elif VWM == 16
- typedef real16 realM;
-#endif
-
-// Data-widths in dimension N
-#if VWN == 1
- typedef real realN;
-#elif VWN == 2
- typedef real2 realN;
-#elif VWN == 4
- typedef real4 realN;
-#elif VWN == 8
- typedef real8 realN;
-#elif VWN == 16
- typedef real16 realN;
-#endif
-
-// =================================================================================================
-
-// Initializes the accumulation registers to zero
-inline void InitAccRegisters(realM cpm[NWI][MWI/VWM]) {
- #pragma unroll
- for (int mi=0; mi<MWI/VWM; ++mi) {
- #pragma unroll
- for (int ni=0; ni<NWI; ++ni) {
- #if VWM == 1
- SetToZero(cpm[ni][mi]);
- #elif VWM == 2
- SetToZero(cpm[ni][mi].x);
- SetToZero(cpm[ni][mi].y);
- #elif VWM == 4
- SetToZero(cpm[ni][mi].x);
- SetToZero(cpm[ni][mi].y);
- SetToZero(cpm[ni][mi].z);
- SetToZero(cpm[ni][mi].w);
- #elif VWM == 8
- SetToZero(cpm[ni][mi].s0);
- SetToZero(cpm[ni][mi].s1);
- SetToZero(cpm[ni][mi].s2);
- SetToZero(cpm[ni][mi].s3);
- SetToZero(cpm[ni][mi].s4);
- SetToZero(cpm[ni][mi].s5);
- SetToZero(cpm[ni][mi].s6);
- SetToZero(cpm[ni][mi].s7);
- #elif VWM == 16
- SetToZero(cpm[ni][mi].s0);
- SetToZero(cpm[ni][mi].s1);
- SetToZero(cpm[ni][mi].s2);
- SetToZero(cpm[ni][mi].s3);
- SetToZero(cpm[ni][mi].s4);
- SetToZero(cpm[ni][mi].s5);
- SetToZero(cpm[ni][mi].s6);
- SetToZero(cpm[ni][mi].s7);
- SetToZero(cpm[ni][mi].s8);
- SetToZero(cpm[ni][mi].s9);
- SetToZero(cpm[ni][mi].sA);
- SetToZero(cpm[ni][mi].sB);
- SetToZero(cpm[ni][mi].sC);
- SetToZero(cpm[ni][mi].sD);
- SetToZero(cpm[ni][mi].sE);
- SetToZero(cpm[ni][mi].sF);
- #endif
- }
- }
-}
-
-// =================================================================================================
-
-// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
-// caching the A input matrix.
-#if SA == 1
-inline void GlobalToLocalA(const __global realM* restrict agm, __local realM* alm,
- const int kSizeM, const int tid, const int kwg) {
- const int la0 = tid % MDIMA;
- const int la1 = tid / MDIMA;
- #pragma unroll
- for (int mia=0; mia<MWA/VWM; ++mia) {
- #pragma unroll
- for (int kia=0; kia<KWA; ++kia) {
-
- // Computes the indices based on strided/non-strided access
- #if STRM == 0
- int mg = mia + la0*(MWA/VWM);
- #elif STRM == 1
- int mg = la0 + mia*MDIMA;
- #endif
-
- // Computes the indices for the global memory
- int kg = kia + la1*KWA;
- int idm = mg + get_group_id(0)*(MWG/VWM);
- int idk = kg + kwg;
-
- // Loads the data from global memory (not transposed) into the local memory
- alm[kg*(MWG/VWM) + mg] = agm[idk*(kSizeM/VWM) + idm];
- }
- }
-}
-#endif
-
-// Same as above, but now for the B input matrix
-#if SB == 1
-inline void GlobalToLocalB(const __global realN* restrict bgm, __local realN* blm,
- const int kSizeN, const int tid, const int kwg) {
- const int lb0 = tid % NDIMB;
- const int lb1 = tid / NDIMB;
- #pragma unroll
- for (int kib=0; kib<KWB; ++kib) {
- #pragma unroll
- for (int nib=0; nib<NWB/VWN; ++nib) {
-
- // Computes the indices based on strided/non-strided access
- #if STRN == 0
- int ng = nib + lb0*(NWB/VWN);
- #elif STRN == 1
- int ng = lb0 + nib*NDIMB;
- #endif
-
- // Computes the indices for the global memory
- int kg = kib + lb1*KWB;
- int idn = ng + get_group_id(1)*(NWG/VWN);
- int idk = kg + kwg;
-
- // Loads the data from global memory (transposed) into the local memory
- blm[kg*(NWG/VWN) + ng] = bgm[idk*(kSizeN/VWN) + idn];
- }
- }
-}
-#endif
-
-// =================================================================================================
-
-// Caches global off-chip memory directly into per-thread private memory (registers). This function
-// is specific for caching the A input matrix.
-#if SA == 0
-inline void GlobalToPrivateA(const __global realM* restrict agm, realM apm[MWI/VWM],
- const int kSizeM, const int idk, const int kwg) {
- #pragma unroll
- for (int mi=0; mi<MWI/VWM; ++mi) {
-
- // Computes the indices based on strided/non-strided access
- #if STRM == 0
- int mg = mi + get_local_id(0)*(MWI/VWM);
- #elif STRM == 1
- int mg = get_local_id(0) + mi*MDIMC;
- #endif
-
- // Computes the indices for the global memory
- int idm = mg + get_group_id(0)*(MWG/VWM);
-
- // Loads the data from global memory (not transposed) and stores into registers
- apm[mi] = agm[idk*(kSizeM/VWM) + idm];
- }
-}
-#endif
-
-// Same as above, but now for the B input matrix
-#if SB == 0
-inline void GlobalToPrivateB(const __global realN* restrict bgm, realN bpm[NWI/VWN],
- const int kSizeN, const int idk) {
- #pragma unroll
- for (int ni=0; ni<NWI/VWN; ++ni) {
-
- // Computes the indices based on strided/non-strided access
- #if STRN == 0
- int ng = ni + get_local_id(1)*(NWI/VWN);
- #elif STRN == 1
- int ng = get_local_id(1) + ni*NDIMC;
- #endif
-
- // Computes the indices for the global memory
- int idn = ng + get_group_id(1)*(NWG/VWN);
-
- // Loads the data from global memory (transposed) and stores into registers
- bpm[ni] = bgm[idk*(kSizeN/VWN) + idn];
- }
-}
-#endif
-
-// =================================================================================================
-
-// Caches on-chip local memory into per-thread private memory (registers). This function is specific
-// for caching the A input matrix.
-#if SA == 1
-inline void LocalToPrivateA(__local realM* alm, realM apm[MWI/VWM], const int kg) {
- #pragma unroll
- for (int mi=0; mi<MWI/VWM; ++mi) {
- #if STRM == 0
- int mg = mi + get_local_id(0)*(MWI/VWM);
- #elif STRM == 1
- int mg = get_local_id(0) + mi*MDIMC;
- #endif
- apm[mi] = alm[kg*(MWG/VWM) + mg];
- }
-}
-#endif
-
-// Same as above, but now for the B input matrix
-#if SB == 1
-inline void LocalToPrivateB(__local realN* blm, realN bpm[NWI/VWN], const int kg) {
- #pragma unroll
- for (int ni=0; ni<NWI/VWN; ++ni) {
- #if STRN == 0
- int ng = ni + get_local_id(1)*(NWI/VWN);
- #elif STRN == 1
- int ng = get_local_id(1) + ni*NDIMC;
- #endif
- bpm[ni] = blm[kg*(NWG/VWN) + ng];
- }
-}
-#endif
-
-// =================================================================================================
-
// The vectorised multiply-add function
inline realM MultiplyAddVector(realM cvec, const realM avec, const real bval) {
#if USE_VECTOR_MAD == 1
diff --git a/src/routines/level3/xgemm.cc b/src/routines/level3/xgemm.cc
index 3961a3fd..5dc2ad7f 100644
--- a/src/routines/level3/xgemm.cc
+++ b/src/routines/level3/xgemm.cc
@@ -36,7 +36,8 @@ Xgemm<T>::Xgemm(Queue &queue, Event &event, const std::string &name):
#include "../../kernels/level3/pad.opencl"
#include "../../kernels/level3/transpose.opencl"
#include "../../kernels/level3/padtranspose.opencl"
- #include "../../kernels/level3/xgemm.opencl"
+ #include "../../kernels/level3/xgemm_part1.opencl"
+ #include "../../kernels/level3/xgemm_part2.opencl"
;
}
diff --git a/src/routines/level3/xher2k.cc b/src/routines/level3/xher2k.cc
index e9970fd1..1711905d 100644
--- a/src/routines/level3/xher2k.cc
+++ b/src/routines/level3/xher2k.cc
@@ -34,7 +34,8 @@ Xher2k<T,U>::Xher2k(Queue &queue, Event &event, const std::string &name):
#include "../../kernels/level3/pad.opencl"
#include "../../kernels/level3/transpose.opencl"
#include "../../kernels/level3/padtranspose.opencl"
- #include "../../kernels/level3/xgemm.opencl"
+ #include "../../kernels/level3/xgemm_part1.opencl"
+ #include "../../kernels/level3/xgemm_part2.opencl"
;
}
diff --git a/src/routines/level3/xherk.cc b/src/routines/level3/xherk.cc
index 49fd12af..cbd0a188 100644
--- a/src/routines/level3/xherk.cc
+++ b/src/routines/level3/xherk.cc
@@ -34,7 +34,8 @@ Xherk<T,U>::Xherk(Queue &queue, Event &event, const std::string &name):
#include "../../kernels/level3/pad.opencl"
#include "../../kernels/level3/transpose.opencl"
#include "../../kernels/level3/padtranspose.opencl"
- #include "../../kernels/level3/xgemm.opencl"
+ #include "../../kernels/level3/xgemm_part1.opencl"
+ #include "../../kernels/level3/xgemm_part2.opencl"
;
}
diff --git a/src/routines/level3/xsyr2k.cc b/src/routines/level3/xsyr2k.cc
index 966a000f..79090871 100644
--- a/src/routines/level3/xsyr2k.cc
+++ b/src/routines/level3/xsyr2k.cc
@@ -36,7 +36,8 @@ Xsyr2k<T>::Xsyr2k(Queue &queue, Event &event, const std::string &name):
#include "../../kernels/level3/pad.opencl"
#include "../../kernels/level3/transpose.opencl"
#include "../../kernels/level3/padtranspose.opencl"
- #include "../../kernels/level3/xgemm.opencl"
+ #include "../../kernels/level3/xgemm_part1.opencl"
+ #include "../../kernels/level3/xgemm_part2.opencl"
;
}
diff --git a/src/routines/level3/xsyrk.cc b/src/routines/level3/xsyrk.cc
index 630cb731..ca429bd7 100644
--- a/src/routines/level3/xsyrk.cc
+++ b/src/routines/level3/xsyrk.cc
@@ -36,7 +36,8 @@ Xsyrk<T>::Xsyrk(Queue &queue, Event &event, const std::string &name):
#include "../../kernels/level3/pad.opencl"
#include "../../kernels/level3/transpose.opencl"
#include "../../kernels/level3/padtranspose.opencl"
- #include "../../kernels/level3/xgemm.opencl"
+ #include "../../kernels/level3/xgemm_part1.opencl"
+ #include "../../kernels/level3/xgemm_part2.opencl"
;
}
diff --git a/src/tuning/xgemm.cc b/src/tuning/xgemm.cc
index c06e3e72..2b4ff456 100644
--- a/src/tuning/xgemm.cc
+++ b/src/tuning/xgemm.cc
@@ -31,7 +31,8 @@ class TuneXgemm {
static std::string GetSources() {
return
#include "../src/kernels/common.opencl"
- #include "../src/kernels/level3/xgemm.opencl"
+ #include "../src/kernels/level3/xgemm_part1.opencl"
+ #include "../src/kernels/level3/xgemm_part2.opencl"
;
}