diff options
-rw-r--r-- | CMakeLists.txt | 3 | ||||
-rw-r--r-- | src/database/database.cpp | 2 | ||||
-rw-r--r-- | src/database/database.hpp | 1 | ||||
-rw-r--r-- | src/database/kernels/xgemm_direct.hpp | 76 | ||||
-rw-r--r-- | src/kernels/common.opencl | 2 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_direct.opencl | 455 | ||||
-rw-r--r-- | src/routines/level3/xgemm.cpp | 102 | ||||
-rw-r--r-- | src/routines/level3/xgemm.hpp | 23 | ||||
-rw-r--r-- | src/tuning/kernels/xgemm_direct.cpp | 191 |
9 files changed, 852 insertions, 3 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 21e38f1d..07cb9283 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -134,7 +134,8 @@ endif() # ================================================================================================== # Sets the supported routines and the used kernels. New routines and kernels should be added here. -set(KERNELS copy_fast copy_pad transpose_fast transpose_pad xaxpy xdot xger xgemm xgemv) +set(KERNELS copy_fast copy_pad transpose_fast transpose_pad xaxpy xdot xger + xgemm xgemm_direct xgemv) set(SAMPLE_PROGRAMS_CPP sgemm) set(SAMPLE_PROGRAMS_C sasum dgemv sgemm haxpy cache) set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax) diff --git a/src/database/database.cpp b/src/database/database.cpp index 34c44a29..2696fb9b 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -21,6 +21,7 @@ #include "database/kernels/xgemv_fast_rot.hpp" #include "database/kernels/xger.hpp" #include "database/kernels/xgemm.hpp" +#include "database/kernels/xgemm_direct.hpp" #include "database/kernels/copy.hpp" #include "database/kernels/pad.hpp" #include "database/kernels/transpose.hpp" @@ -38,6 +39,7 @@ const std::vector<Database::DatabaseEntry> Database::database = { XgemvFastRotHalf, XgemvFastRotSingle, XgemvFastRotDouble, XgemvFastRotComplexSingle, XgemvFastRotComplexDouble, XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble, XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble, + XgemmDirectHalf, XgemmDirectSingle, XgemmDirectDouble, XgemmDirectComplexSingle, XgemmDirectComplexDouble, CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble, PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble, TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble, diff --git a/src/database/database.hpp b/src/database/database.hpp index a6ab49c5..7c0afb46 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -75,6 +75,7 @@ class Database { static const DatabaseEntry XgemvFastRotHalf, XgemvFastRotSingle, XgemvFastRotDouble, XgemvFastRotComplexSingle, XgemvFastRotComplexDouble; static const DatabaseEntry XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble; static const DatabaseEntry XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble; + static const DatabaseEntry XgemmDirectHalf, XgemmDirectSingle, XgemmDirectDouble, XgemmDirectComplexSingle, XgemmDirectComplexDouble; static const DatabaseEntry CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble; static const DatabaseEntry PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble; static const DatabaseEntry TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble; diff --git a/src/database/kernels/xgemm_direct.hpp b/src/database/kernels/xgemm_direct.hpp new file mode 100644 index 00000000..dc69f61b --- /dev/null +++ b/src/database/kernels/xgemm_direct.hpp @@ -0,0 +1,76 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Database generator <database.py> +// +// This file populates the database with best-found tuning parameters for the 'Xgemm' kernels. +// +// ================================================================================================= + +namespace clblast { +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectHalf = { + "XgemmDirect", Precision::kHalf, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectSingle = { + "XgemmDirect", Precision::kSingle, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectComplexSingle = { + "XgemmDirect", Precision::kComplexSingle, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectDouble = { + "XgemmDirect", Precision::kDouble, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectComplexDouble = { + "XgemmDirect", Precision::kComplexDouble, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= +} // namespace clblast diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl index 223501fd..b0817242 100644 --- a/src/kernels/common.opencl +++ b/src/kernels/common.opencl @@ -204,7 +204,7 @@ R"( #if PRECISION == 3232 || PRECISION == 6464 #define COMPLEX_CONJUGATE(value) value.x = value.x; value.y = -value.y #else - #define COMPLEX_CONJUGATE(value) value = value + #define COMPLEX_CONJUGATE(value) #endif // ================================================================================================= diff --git a/src/kernels/level3/xgemm_direct.opencl b/src/kernels/level3/xgemm_direct.opencl new file mode 100644 index 00000000..705ced9c --- /dev/null +++ b/src/kernels/level3/xgemm_direct.opencl @@ -0,0 +1,455 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This is a generic GEMM kernel that works for all sizes and configurations: it doesn't require any +// pre and and post-processing kernels. +// +// ================================================================================================= + +// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string +// literal). Comment-out this line for syntax-highlighting when developing. +R"( + +// Parameters set by the tuner or by the database. Here they are given a basic default value in case +// this kernel file is used outside of the CLBlast library. Note that all parameters here have a +// suffix 'D' to denote that they are for the 'direct' version of the GEMM kernel. +#ifndef WGD + #define WGD 8 // Tile-size in dimension M, N, and K (e.g. 8, 16, 32, 64) +#endif +#ifndef MDIMCD + #define MDIMCD 8 // Threads per workgroup in M-dimension (e.g. 8, 16, 32) +#endif +#ifndef NDIMCD + #define NDIMCD 8 // Threads per workgroup in N-dimension (e.g. 8, 16, 32) +#endif +#ifndef MDIMAD + #define MDIMAD 8 // Re-shaped tile dimension of matrix A: KDIMAD * MDIMAD +#endif +#ifndef NDIMBD + #define NDIMBD 8 // Re-shaped tile dimension of matrix B: KDIMBD * NDIMBD +#endif +#ifndef KWID + #define KWID 1 // Unroll factor of the WGD loop (smaller or equal than WGD) +#endif +#ifndef VWMD + #define VWMD 1 // Vector width of matrices A and C +#endif +#ifndef VWND + #define VWND 1 // Vector width of matrix B +#endif + +// Helper parameters based on the above tuning parameters +#define MWID (WGD/MDIMCD) // Work per work-item (M-dimension) +#define NWID (WGD/NDIMCD) // Work per work-item (N-dimension) +#define KDIMAD ((MDIMCD*NDIMCD)/(MDIMAD)) // Re-shaped tile dimension of matrix A: KDIMAD * MDIMAD +#define KDIMBD ((MDIMCD*NDIMCD)/(NDIMBD)) // Re-shaped tile dimension of matrix B: KDIMBD * NDIMBD +#define MWAD (WGD/MDIMAD) // Amount of loads-per-thread for matrix A (M-dimension) +#define KWAD (WGD/KDIMAD) // Amount of loads-per-thread for matrix A (K-dimension) +#define KWBD (WGD/KDIMBD) // Amount of loads-per-thread for matrix B (K-dimension) +#define NWBD (WGD/NDIMBD) // Amount of loads-per-thread for matrix B (N-dimension) + +// ================================================================================================= + +// Data-widths in dimension M +#if VWMD == 1 + typedef real realMD; +#elif VWMD == 2 + typedef real2 realMD; +#elif VWMD == 4 + typedef real4 realMD; +#elif VWMD == 8 + typedef real8 realMD; +#elif VWMD == 16 + typedef real16 realMD; +#endif + +// Data-widths in dimension N +#if VWND == 1 + typedef real realND; +#elif VWND == 2 + typedef real2 realND; +#elif VWND == 4 + typedef real4 realND; +#elif VWND == 8 + typedef real8 realND; +#elif VWND == 16 + typedef real16 realND; +#endif + +// ================================================================================================= + +// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for +// caching the A input matrix. +inline void GlobalToLocalDirectA(const __global realMD* restrict agm, __local real* alm, + const int a_ld, const int a_offset, const int tid, const int kwg, + const int a_transpose, const int a_conjugate) { + const int la0 = tid % MDIMAD; + const int la1 = tid / MDIMAD; + #pragma unroll + for (int mia=0; mia<MWAD/VWMD; ++mia) { + #pragma unroll + for (int kia=0; kia<KWAD; ++kia) { + + // Computes the indices for the global memory + int mg = mia + la0*(MWAD/VWMD); + int kg = kia + la1*KWAD; + int idm = (a_transpose) ? mg + kwg/VWMD : mg + GetGroupID0()*(WGD/VWMD); + int idk = (a_transpose) ? kg + GetGroupID0()*WGD : kg + kwg; + + // Loads the data from global memory into the local memory + const realMD avec = agm[idk*(a_ld/VWMD) + idm + a_offset]; + #if VWMD == 1 + alm[kg*WGD + mg] = avec; + #elif VWMD == 2 + alm[kg*WGD + mg*VWMD + 0] = avec.x; + alm[kg*WGD + mg*VWMD + 1] = avec.y; + #elif VWMD == 4 + alm[kg*WGD + mg*VWMD + 0] = avec.x; + alm[kg*WGD + mg*VWMD + 1] = avec.y; + alm[kg*WGD + mg*VWMD + 2] = avec.z; + alm[kg*WGD + mg*VWMD + 3] = avec.w; + #elif VWMD == 8 + alm[kg*WGD + mg*VWMD + 0] = avec.s0; + alm[kg*WGD + mg*VWMD + 1] = avec.s1; + alm[kg*WGD + mg*VWMD + 2] = avec.s2; + alm[kg*WGD + mg*VWMD + 3] = avec.s3; + alm[kg*WGD + mg*VWMD + 4] = avec.s4; + alm[kg*WGD + mg*VWMD + 5] = avec.s5; + alm[kg*WGD + mg*VWMD + 6] = avec.s6; + alm[kg*WGD + mg*VWMD + 7] = avec.s7; + #elif VWMD == 16 + alm[kg*WGD + mg*VWMD + 0] = avec.s0; + alm[kg*WGD + mg*VWMD + 1] = avec.s1; + alm[kg*WGD + mg*VWMD + 2] = avec.s2; + alm[kg*WGD + mg*VWMD + 3] = avec.s3; + alm[kg*WGD + mg*VWMD + 4] = avec.s4; + alm[kg*WGD + mg*VWMD + 5] = avec.s5; + alm[kg*WGD + mg*VWMD + 6] = avec.s6; + alm[kg*WGD + mg*VWMD + 7] = avec.s7; + alm[kg*WGD + mg*VWMD + 8] = avec.s8; + alm[kg*WGD + mg*VWMD + 9] = avec.s9; + alm[kg*WGD + mg*VWMD + 10] = avec.sA; + alm[kg*WGD + mg*VWMD + 11] = avec.sB; + alm[kg*WGD + mg*VWMD + 12] = avec.sC; + alm[kg*WGD + mg*VWMD + 13] = avec.sD; + alm[kg*WGD + mg*VWMD + 14] = avec.sE; + alm[kg*WGD + mg*VWMD + 15] = avec.sF; + #endif + if (a_conjugate) { + for (int vm=0; vm<VWMD; ++vm) { + COMPLEX_CONJUGATE(alm[kg*WGD + mg*VWMD + vm]); + } + } + } + } +} + +// Same as above, but now for the B input matrix +inline void GlobalToLocalDirectB(const __global realND* restrict bgm, __local real* blm, + const int b_ld, const int b_offset, const int tid, const int kwg, + const int b_transpose, const int b_conjugate) { + const int lb0 = tid % NDIMBD; + const int lb1 = tid / NDIMBD; + #pragma unroll + for (int kib=0; kib<KWBD; ++kib) { + #pragma unroll + for (int nib=0; nib<NWBD/VWND; ++nib) { + + // Computes the indices for the global memory + int ng = nib + lb0*(NWBD/VWND); + int kg = kib + lb1*KWBD; + int idn = (b_transpose) ? ng + kwg/VWND : ng + GetGroupID1()*(WGD/VWND); + int idk = (b_transpose) ? kg + GetGroupID1()*WGD : kg + kwg; + + // Loads the data from global memory into the local memory + const realND bvec = bgm[idk*(b_ld/VWND) + idn + b_offset]; + #if VWND == 1 + blm[kg*WGD + ng] = bvec; + #elif VWND == 2 + blm[kg*WGD + ng*VWND + 0] = bvec.x; + blm[kg*WGD + ng*VWND + 1] = bvec.y; + #elif VWND == 4 + blm[kg*WGD + ng*VWND + 0] = bvec.x; + blm[kg*WGD + ng*VWND + 1] = bvec.y; + blm[kg*WGD + ng*VWND + 2] = bvec.z; + blm[kg*WGD + ng*VWND + 3] = bvec.w; + #elif VWND == 8 + blm[kg*WGD + ng*VWND + 0] = bvec.s0; + blm[kg*WGD + ng*VWND + 1] = bvec.s1; + blm[kg*WGD + ng*VWND + 2] = bvec.s2; + blm[kg*WGD + ng*VWND + 3] = bvec.s3; + blm[kg*WGD + ng*VWND + 4] = bvec.s4; + blm[kg*WGD + ng*VWND + 5] = bvec.s5; + blm[kg*WGD + ng*VWND + 6] = bvec.s6; + blm[kg*WGD + ng*VWND + 7] = bvec.s7; + #elif VWND == 16 + blm[kg*WGD + ng*VWND + 0] = bvec.s0; + blm[kg*WGD + ng*VWND + 1] = bvec.s1; + blm[kg*WGD + ng*VWND + 2] = bvec.s2; + blm[kg*WGD + ng*VWND + 3] = bvec.s3; + blm[kg*WGD + ng*VWND + 4] = bvec.s4; + blm[kg*WGD + ng*VWND + 5] = bvec.s5; + blm[kg*WGD + ng*VWND + 6] = bvec.s6; + blm[kg*WGD + ng*VWND + 7] = bvec.s7; + blm[kg*WGD + ng*VWND + 8] = bvec.s8; + blm[kg*WGD + ng*VWND + 9] = bvec.s9; + blm[kg*WGD + ng*VWND + 10] = bvec.sA; + blm[kg*WGD + ng*VWND + 11] = bvec.sB; + blm[kg*WGD + ng*VWND + 12] = bvec.sC; + blm[kg*WGD + ng*VWND + 13] = bvec.sD; + blm[kg*WGD + ng*VWND + 14] = bvec.sE; + blm[kg*WGD + ng*VWND + 15] = bvec.sF; + #endif + if (b_conjugate) { + for (int vn=0; vn<VWND; ++vn) { + COMPLEX_CONJUGATE(blm[kg*WGD + ng*VWND + vn]); + } + } + } + } +} + +// ================================================================================================= + +// Caches on-chip local memory into per-thread private memory (registers). This function is specific +// for caching the A input matrix. +inline void LocalToPrivateDirectA(__local real* alm, real apm[MWID], const int kg, + const int a_transpose) { + #pragma unroll + for (int mi=0; mi<MWID; ++mi) { + const int mg = mi + get_local_id(0)*MWID; + const int index = (a_transpose) ? mg*WGD + kg : kg*WGD + mg; + apm[mi] = alm[index]; + } +} + +// Same as above, but now for the B input matrix +inline void LocalToPrivateDirectB(__local real* blm, real bpm[NWID], const int kg, + const int b_transpose) { + #pragma unroll + for (int ni=0; ni<NWID; ++ni) { + const int ng = ni + get_local_id(1)*NWID; + const int index = (b_transpose) ? ng*WGD + kg : kg*WGD + ng; + bpm[ni] = blm[index]; + } +} + +// ================================================================================================= + +// Initializes the accumulation registers to zero +inline void InitAccRegistersDirect(real cpm[NWID][MWID]) { + #pragma unroll + for (int mi=0; mi<MWID; ++mi) { + #pragma unroll + for (int ni=0; ni<NWID; ++ni) { + SetToZero(cpm[ni][mi]); + } + } +} + +// ================================================================================================= + +// Performs the actual computation: Cpm += Apm * Bpm +inline void MultiplyAccumulateDirect(real cpm[NWID][MWID], real apm[MWID], real bpm[NWID]) { + #pragma unroll + for (int ni=0; ni<NWID; ++ni) { + #pragma unroll + for (int mi=0; mi<MWID; ++mi) { + MultiplyAdd(cpm[ni][mi], apm[mi], bpm[ni]); + } + } +} + +// ================================================================================================= + +// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication +// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm +inline void StoreResultsDirect(__global real* cgm, real cpm[NWID][MWID], + const int kSizeM, const int kSizeN, + const real alpha, const real beta, + const int c_ld, const int c_offset, const int c_transpose) { + #pragma unroll + for (int ni=0; ni<NWID; ++ni) { + #pragma unroll + for (int mi=0; mi<MWID; ++mi) { + int mg = mi + get_local_id(0)*MWID; + int ng = ni + get_local_id(1)*NWID; + int idm = mg + GetGroupID0() * WGD; + int idn = ng + GetGroupID1() * WGD; + + // Determines the destination index + const int c_index = (c_transpose) ? idm*c_ld + idn : idn*c_ld + idm; + + // The final multiplication with alpha and the addition with beta*C + real result; + AXPBY(result, alpha, cpm[ni][mi], beta, cgm[c_index + c_offset]); + cgm[c_index + c_offset] = result; + } + } +} + +// ================================================================================================= + +// Main entry point of the kernel. This is the direct version without restrictions. +__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1))) +__kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK, + const real_arg arg_alpha, + const real_arg arg_beta, + const __global realMD* restrict agm, const int a_offset, const int a_ld, + const __global realND* restrict bgm, const int b_offset, const int b_ld, + __global real* cgm, const int c_offset, const int c_ld, + const int a_transpose, const int b_transpose, const int c_transpose, + const int a_conjugate, const int b_conjugate) { + const real alpha = GetRealArg(arg_alpha); + const real beta = GetRealArg(arg_beta); + + // Extra pointers to scalar versions of global memory + const __global real* restrict agms = (const __global real* restrict) agm; + const __global real* restrict bgms = (const __global real* restrict) bgm; + + // Allocates workgroup-private memory (local memory) + __local real alm[WGD * WGD]; + __local real blm[WGD * WGD]; + + // Combined thread identifier (volatile to disable caching) + volatile int tid = get_local_id(0) + MDIMCD*get_local_id(1); + + // Allocates workitem-private memory (registers) + real apm[MWID]; + real bpm[NWID]; + real cpm[NWID][MWID]; + + // Initializes the accumulation registers + InitAccRegistersDirect(cpm); + + // The faster version of GEMM is not allowed on the (incomplete) borders. Therefore, this section + // processes only the main parts: output blocks of WGD by WGD. + const int idm = get_local_id(0) * MWID + GetGroupID0() * WGD; + const int idn = get_local_id(1) * NWID + GetGroupID1() * WGD; + if ((idm < (kSizeM/WGD)*WGD) && (idn < (kSizeN/WGD)*WGD) && + (a_ld % VWMD == 0) && (b_ld % VWND == 0)) { + + // Loops over all complete workgroup tiles + int kwg = 0; + for (; kwg < (kSizeK/WGD) * WGD; kwg+=WGD) { + + // Loads data: off-chip --> local (matrix A and B) + GlobalToLocalDirectA(agm, alm, a_ld, a_offset, tid, kwg, a_transpose, a_conjugate); + GlobalToLocalDirectB(bgm, blm, b_ld, b_offset, tid, kwg, b_transpose, b_conjugate); + barrier(CLK_LOCAL_MEM_FENCE); + + // Loops over all workitem tiles, unrolled by a factor KWID + for (int pwi=0; pwi<WGD; pwi+=KWID) { + #pragma unroll + for (int pit=0; pit<KWID; ++pit) { + int kg = pwi + pit; + + // Loads data: local --> private (matrix A) + LocalToPrivateDirectA(alm, apm, kg, a_transpose); + + // Loads data: local --> private (matrix B) + LocalToPrivateDirectB(blm, bpm, kg, b_transpose); + + // Performs the accumulation (Cpm += Apm * Bpm) + MultiplyAccumulateDirect(cpm, apm, bpm); + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + // Loop over the remaining part (incomplete tile in K-dimension) + for (; kwg < kSizeK; ++kwg) { + const int idk = kwg; + + // Loads A into register memory + #pragma unroll + for (int mi=0; mi<MWID; ++mi) { + const int a_index = (a_transpose) ? (idm + mi)*a_ld + idk : idk*a_ld + (idm + mi); + apm[mi] = agms[a_index + a_offset]; + if (a_conjugate) { COMPLEX_CONJUGATE(apm[mi]); } + } + + // Loads B into register memory + #pragma unroll + for (int ni=0; ni<NWID; ++ni) { + const int b_index = (b_transpose) ? (idn + ni)*b_ld + idk : idk*b_ld + (idn + ni); + bpm[ni] = bgms[b_index + b_offset]; + if (b_conjugate) { COMPLEX_CONJUGATE(bpm[ni]); } + } + + // Performs the accumulation (Cpm += Apm * Bpm) + MultiplyAccumulateDirect(cpm, apm, bpm); + } + + // Stores a tile of results and performs the multiplication with alpha and beta + StoreResultsDirect(cgm, cpm, kSizeM, kSizeN, alpha, beta, c_ld, c_offset, c_transpose); + } + + // Simple but slow version for the parts on the edge (incomplete tiles in M and N-dimensions) + else { + + // Loop over the K-dimension + for (int idk = 0; idk < kSizeK; ++idk) { + + // Loads A into register memory + #pragma unroll + for (int mi=0; mi<MWID; ++mi) { + if (idm + mi < kSizeM) { + const int a_index = (a_transpose) ? (idm + mi)*a_ld + idk : idk*a_ld + (idm + mi); + apm[mi] = agms[a_index + a_offset]; + if (a_conjugate) { COMPLEX_CONJUGATE(apm[mi]); } + } + else { + SetToZero(apm[mi]); + } + } + + // Loads B into register memory + #pragma unroll + for (int ni=0; ni<NWID; ++ni) { + if (idn + ni < kSizeN) { + const int b_index = (b_transpose) ? (idn + ni)*b_ld + idk : idk*b_ld + (idn + ni); + bpm[ni] = bgms[b_index + b_offset]; + if (b_conjugate) { COMPLEX_CONJUGATE(bpm[ni]); } + } + else { + SetToZero(bpm[ni]); + } + } + + // Performs the accumulation (Cpm += Apm * Bpm) + MultiplyAccumulateDirect(cpm, apm, bpm); + } + + // Stores the results + #pragma unroll + for (int ni=0; ni<NWID; ++ni) { + #pragma unroll + for (int mi=0; mi<MWID; ++mi) { + if ((idm + mi) < kSizeM && (idn + ni) < kSizeN) { + + // Determines the destination index + const int c_index = (c_transpose) ? (idm + mi)*c_ld + (idn + ni) : (idn + ni)*c_ld + (idm + mi); + + // Computes and stores the result + real result; + AXPBY(result, alpha, cpm[ni][mi], beta, cgm[c_index + c_offset]); + cgm[c_index + c_offset] = result; + } + } + } + } +} + +// ================================================================================================= + +// End of the C++11 raw string literal +)" + +// ================================================================================================= diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 0b8e768f..e050e844 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -22,7 +22,8 @@ namespace clblast { // Constructor: forwards to base class constructor template <typename T> Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, {"Copy","Pad","Transpose","Padtranspose","Xgemm"}, PrecisionValue<T>()) { + Routine(queue, event, name, {"Copy","Pad","Transpose","Padtranspose","Xgemm", "XgemmDirect"}, + PrecisionValue<T>()) { source_string_ = #include "../../kernels/level3/level3.opencl" #include "../../kernels/level3/copy_fast.opencl" @@ -35,6 +36,7 @@ Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" #include "../../kernels/level3/xgemm_part3.opencl" + #include "../../kernels/level3/xgemm_direct.opencl" ; } @@ -98,6 +100,44 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout, status = TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld); if (ErrorIn(status)) { return status; } + // Optionally runs the direct version of GEMM. TODO: Set this based on the arguments + const auto do_gemm_direct = true; // for now, for testing + if (do_gemm_direct) { + return GemmDirect(m, n, k, alpha, + a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, + c_buffer, c_offset, c_ld, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + } + else { + return GemmIndirect(m, n, k, alpha, + a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, + c_buffer, c_offset, c_ld, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + a_one, a_two, a_want_rotated, + b_one, b_two, b_want_rotated, + c_one, c_two, c_want_rotated); + } +} + +// ================================================================================================= + +// The indirect version of GEMM. This uses the faster but non-general kernel. It has specific +// requirements, but several pre and post-processing kernels take care of those. However, the +// overhead of these extra kernels might not be ideal for certain devices/arguments. +template <typename T> +StatusCode Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate, + const size_t a_one, const size_t a_two, const bool a_want_rotated, + const size_t b_one, const size_t b_two, const bool b_want_rotated, + const size_t c_one, const size_t c_two, const bool c_want_rotated) { + auto status = StatusCode::kSuccess; + // Calculates the ceiled versions of m, n, and k const auto m_ceiled = Ceil(m, db_["MWG"]); const auto n_ceiled = Ceil(n, db_["NWG"]); @@ -217,6 +257,66 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout, } catch (...) { return StatusCode::kTempBufferAllocFailure; } } + +// ================================================================================================= + +// The direct version of GEMM, requiring just one kernel, no pre or post-processing kernels. +template <typename T> +StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate) { + + // Loads the program from the database + const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_); + + // Retrieves the XgemmDirect kernel from the compiled binary + try { + auto kernel = Kernel(program, "XgemmDirect"); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast<int>(m)); + kernel.SetArgument(1, static_cast<int>(n)); + kernel.SetArgument(2, static_cast<int>(k)); + kernel.SetArgument(3, GetRealArg(alpha)); + kernel.SetArgument(4, GetRealArg(beta)); + kernel.SetArgument(5, a_buffer()); + kernel.SetArgument(6, static_cast<int>(a_offset)); + kernel.SetArgument(7, static_cast<int>(a_ld)); + kernel.SetArgument(8, b_buffer()); + kernel.SetArgument(9, static_cast<int>(b_offset)); + kernel.SetArgument(10, static_cast<int>(b_ld)); + kernel.SetArgument(11, c_buffer()); + kernel.SetArgument(12, static_cast<int>(c_offset)); + kernel.SetArgument(13, static_cast<int>(c_ld)); + kernel.SetArgument(14, static_cast<int>(a_do_transpose)); + kernel.SetArgument(15, static_cast<int>(b_do_transpose)); + kernel.SetArgument(16, static_cast<int>(c_do_transpose)); + kernel.SetArgument(17, static_cast<int>(a_conjugate)); + kernel.SetArgument(18, static_cast<int>(b_conjugate)); + + // Computes the global and local thread sizes + const auto m_ceiled = Ceil(m, db_["WGD"]); + const auto n_ceiled = Ceil(n, db_["WGD"]); + const auto global = std::vector<size_t>{ + (m_ceiled * db_["MDIMCD"]) / db_["WGD"], + (n_ceiled * db_["NDIMCD"]) / db_["WGD"] + }; + const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"]}; + + // Launches the kernel + auto status = RunKernel(kernel, queue_, device_, global, local, event_); + if (ErrorIn(status)) { return status; } + + // Successfully finished the computation + return StatusCode::kSuccess; + } catch (...) { return StatusCode::kInvalidKernel; } +} + // ================================================================================================= // Compiles the templated class diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index bc51c7f5..46e12453 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -35,6 +35,29 @@ class Xgemm: public Routine { const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const T beta, const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld); + + // Indirect version of GEMM (with pre and post-processing kernels) + StatusCode GemmIndirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate, + const size_t a_one, const size_t a_two, const bool a_want_rotated, + const size_t b_one, const size_t b_two, const bool b_want_rotated, + const size_t c_one, const size_t c_two, const bool c_want_rotated); + + // Direct version of GEMM (no pre and post-processing kernels) + StatusCode GemmDirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate); }; // ================================================================================================= diff --git a/src/tuning/kernels/xgemm_direct.cpp b/src/tuning/kernels/xgemm_direct.cpp new file mode 100644 index 00000000..c2e8710f --- /dev/null +++ b/src/tuning/kernels/xgemm_direct.cpp @@ -0,0 +1,191 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This file uses the CLTune auto-tuner to tune the direct xgemm kernels. There are two variations: +// - V==1: This tests some limited set of tuning parameters exhaustively. +// - V==2: This tests a much larger set of tuning parameters by randomly sampling a subset. +// +// ================================================================================================= + +#include <string> +#include <vector> + +#include "utilities.hpp" +#include "tuning/tuning.hpp" + +namespace clblast { +// ================================================================================================= + +// See comment at top of file for a description of the class +template <typename T, int V> +class TuneXgemmDirect { + public: + + // The representative kernel and the source code + static std::string KernelFamily() { return (V==1) ? "xgemm_direct_1" : "xgemm_direct_2"; } + static std::string KernelName() { return "XgemmDirect"; } + static std::string GetSources() { + return + #include "../src/kernels/common.opencl" + #include "../src/kernels/level3/xgemm_direct.opencl" + ; + } + + // The list of arguments relevant for this routine + static std::vector<std::string> GetOptions() { + return {kArgM, kArgN, kArgK, kArgAlpha, kArgBeta, kArgFraction}; + } + + // Tests for valid arguments + static void TestValidArguments(const Arguments<T> &) { } + + // Sets the default values for the arguments + static size_t DefaultM() { return 128; } + static size_t DefaultN() { return 128; } + static size_t DefaultK() { return 128; } + static double DefaultFraction() { return (V==1) ? 1.0 : 16.0; } // test all or sample randomly + + // Describes how to obtain the sizes of the buffers + static size_t GetSizeX(const Arguments<T> &) { return 1; } // N/A for this kernel + static size_t GetSizeY(const Arguments<T> &) { return 1; } // N/A for this kernel + static size_t GetSizeA(const Arguments<T> &args) { return args.m * args.k; } + static size_t GetSizeB(const Arguments<T> &args) { return args.n * args.k; } + static size_t GetSizeC(const Arguments<T> &args) { return args.m * args.n; } + static size_t GetSizeTemp(const Arguments<T> &) { return 1; } // N/A for this kernel + + // Sets the tuning parameters and their possible values + static void SetParameters(cltune::Tuner &tuner, const size_t id) { + if (V==1) { // limited subset of tuning parameters - but explorable exhaustively + tuner.AddParameter(id, "WGD", {8, 16, 32}); + tuner.AddParameter(id, "MDIMCD", {8, 16, 32}); + tuner.AddParameter(id, "NDIMCD", {8, 16, 32}); + tuner.AddParameter(id, "MDIMAD", {8, 16, 32}); + tuner.AddParameter(id, "NDIMBD", {8, 16, 32}); + tuner.AddParameter(id, "KWID", {2}); + tuner.AddParameter(id, "VWMD", {1, 2, 4, 8}); + tuner.AddParameter(id, "VWND", {1, 2, 4, 8}); + } // a lot more tuning parameters - has to be sampled randomly, too much to test all + else { + tuner.AddParameter(id, "WGD", {8, 16, 32, 64, 128}); + tuner.AddParameter(id, "MDIMCD", {8, 16, 32}); + tuner.AddParameter(id, "NDIMCD", {8, 16, 32}); + tuner.AddParameter(id, "MDIMAD", {8, 16, 32}); + tuner.AddParameter(id, "NDIMBD", {8, 16, 32}); + tuner.AddParameter(id, "KWID", {2, 8, 16}); + tuner.AddParameter(id, "VWMD", {1, 2, 4, 8}); + tuner.AddParameter(id, "VWND", {1, 2, 4, 8}); + } + } + + // Sets the constraints + static void SetConstraints(cltune::Tuner &tuner, const size_t id) { + auto MultipleOfX = [] (std::vector<size_t> v) { return IsMultiple(v[0], v[1]); }; + auto MultipleOfXMulY = [] (std::vector<size_t> v) { return IsMultiple(v[0], v[1]*v[2]); }; + auto MultipleOfXMulYDivZ = [] (std::vector<size_t> v) { return IsMultiple(v[0], (v[1]*v[2])/v[3]); }; + // Requirement for unrolling the WGD loop + tuner.AddConstraint(id, MultipleOfX, {"WGD", "KWID"}); + // Required for integer MWID and NWID + tuner.AddConstraint(id, MultipleOfXMulY, {"WGD", "MDIMCD", "VWMD"}); + tuner.AddConstraint(id, MultipleOfXMulY, {"WGD", "NDIMCD", "VWND"}); + // Required for integer MWIAD and NWIBD + tuner.AddConstraint(id, MultipleOfXMulY, {"WGD", "MDIMAD", "VWMD"}); + tuner.AddConstraint(id, MultipleOfXMulY, {"WGD", "NDIMBD", "VWND"}); + // WGD has to be a multiple of KDIMAD = ((MDIMCD*NDIMCD)/(MDIMAD)) and KDIMBD = (...) + tuner.AddConstraint(id, MultipleOfXMulYDivZ, {"WGD", "MDIMCD", "NDIMCD", "MDIMAD"}); + tuner.AddConstraint(id, MultipleOfXMulYDivZ, {"WGD", "MDIMCD", "NDIMCD", "NDIMBD"}); + + // Extra constraints for variation 1 to limit the set of options significantly + if (V==1) { + auto IsEqual = [] (std::vector<size_t> v) { return v[0] == v[1]; }; + tuner.AddConstraint(id, IsEqual, {"MDIMCD", "MDIMAD"}); + tuner.AddConstraint(id, IsEqual, {"NDIMCD", "NDIMBD"}); + } + } + + // Sets the local memory size + static void SetLocalMemorySize(cltune::Tuner &tuner, const size_t id, const Arguments<T> &args) { + auto LocalMemorySize = [args] (std::vector<size_t> v) { + return ((v[0]*v[1] + v[2]*v[3])*GetBytes(args.precision)); + }; + tuner.SetLocalMemoryUsage(id, LocalMemorySize, {"WGD", "WGD", "WGD", "WGD"}); + } + + // Sets the base thread configuration + static std::vector<size_t> GlobalSize(const Arguments<T> &args) { return {args.m, args.n}; } + static std::vector<size_t> GlobalSizeRef(const Arguments<T> &args) { return GlobalSize(args); } + static std::vector<size_t> LocalSize() { return {1, 1}; } + static std::vector<size_t> LocalSizeRef() { return {8, 8}; } + + // Transforms the thread configuration based on the parameters + using TransformVector = std::vector<std::vector<std::string>>; + static TransformVector MulLocal() { return {{"MDIMCD", "NDIMCD"}}; } + static TransformVector DivLocal() { return {}; } + static TransformVector MulGlobal() { return {{"MDIMCD", "NDIMCD"}}; } + static TransformVector DivGlobal() { return {{"WGD", "WGD"}}; } + + // Sets the kernel's arguments + static void SetArguments(cltune::Tuner &tuner, const Arguments<T> &args, + std::vector<T> &, std::vector<T> &, + std::vector<T> &a_mat, std::vector<T> &b_mat, std::vector<T> &c_mat, + std::vector<T> &) { + tuner.AddArgumentScalar(static_cast<int>(args.m)); + tuner.AddArgumentScalar(static_cast<int>(args.n)); + tuner.AddArgumentScalar(static_cast<int>(args.k)); + tuner.AddArgumentScalar(GetRealArg(args.alpha)); + tuner.AddArgumentScalar(GetRealArg(args.beta)); + tuner.AddArgumentInput(a_mat); + tuner.AddArgumentScalar(0); // a_offset + tuner.AddArgumentScalar(static_cast<int>(args.k)); // a_ld + tuner.AddArgumentInput(b_mat); + tuner.AddArgumentScalar(0); // b_offset + tuner.AddArgumentScalar(static_cast<int>(args.n)); // b_ld + tuner.AddArgumentOutput(c_mat); + tuner.AddArgumentScalar(0); // c_offset + tuner.AddArgumentScalar(static_cast<int>(args.n)); // c_ld + tuner.AddArgumentScalar(1); // a_do_transpose + tuner.AddArgumentScalar(1); // b_do_transpose + tuner.AddArgumentScalar(1); // c_do_transpose + tuner.AddArgumentScalar(0); // a_conjugate + tuner.AddArgumentScalar(0); // b_conjugate + } + + // Describes how to compute the performance metrics + static size_t GetMetric(const Arguments<T> &args) { + return 2 * args.m * args.n * args.k; + } + static std::string PerformanceUnit() { return "GFLOPS"; } +}; + +// ================================================================================================= +} // namespace clblast + +// Shortcuts to the clblast namespace +using float2 = clblast::float2; +using double2 = clblast::double2; + +// Function to tune a specific variation V (not within the clblast namespace) +template <int V> +void StartVariation(int argc, char *argv[]) { + switch(clblast::GetPrecision(argc, argv)) { + case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXgemmDirect<half,V>, half>(argc, argv); break; + case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXgemmDirect<float,V>, float>(argc, argv); break; + case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXgemmDirect<double,V>, double>(argc, argv); break; + case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXgemmDirect<float2,V>, float2>(argc, argv); break; + case clblast::Precision::kComplexDouble: clblast::Tuner<clblast::TuneXgemmDirect<double2,V>, double2>(argc, argv); break; + } +} + +// Main function (not within the clblast namespace) +int main(int argc, char *argv[]) { + StartVariation<1>(argc, argv); + StartVariation<2>(argc, argv); + return 0; +} + +// ================================================================================================= |