diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-08 21:07:01 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-08 21:07:01 +0100 |
commit | 99a4df88a6d808ea77c9116ce63621503c00b57a (patch) | |
tree | d8f8bc1b3884c0340df9f6d95b4837ed3dff8deb | |
parent | 13f0f6fc6e612a5f77c6fd78b983f1b2bb8e36b6 (diff) |
Implemented the in-direct version of the strided-batched GEMM kernel
-rw-r--r-- | CHANGELOG | 3 | ||||
-rw-r--r-- | src/kernels/level3/copy_pad.opencl | 39 | ||||
-rw-r--r-- | src/kernels/level3/transpose_pad.opencl | 41 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_batched.opencl | 45 | ||||
-rw-r--r-- | src/routines/common.hpp | 66 | ||||
-rw-r--r-- | src/routines/levelx/xgemmstridedbatched.cpp | 65 |
6 files changed, 215 insertions, 44 deletions
@@ -10,6 +10,9 @@ Development (next version) - Improved compilation time by splitting the tuning database into multiple compilation units - Various minor fixes and enhancements - Added tuned parameters for various devices (see README) +- Added a strided-batched (not part of the BLAS standard) routine, faster but less generic compared + to the existing xGEMMBATCHED routines: + * SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED Version 1.2.0 - Fixed a bug in the TRSM/TRSV routines due to missing synchronisations after GEMM/GEMV calls diff --git a/src/kernels/level3/copy_pad.opencl b/src/kernels/level3/copy_pad.opencl index 2e191514..3d389b74 100644 --- a/src/kernels/level3/copy_pad.opencl +++ b/src/kernels/level3/copy_pad.opencl @@ -174,6 +174,45 @@ void CopyMatrixBatched(const int src_one, const int src_two, #endif // ================================================================================================= +#if defined(ROUTINE_GEMMSTRIDEDBATCHED) + +// Strided-batched version of the above +__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1))) +void CopyPadMatrixStridedBatched(const int src_one, const int src_two, + const int src_ld, const int src_offset, + const int src_stride, __global const real* restrict src, + const int dest_one, const int dest_two, + const int dest_ld, const int dest_offset, + const int dest_stride, __global real* dest, + const int do_conjugate) { + const int batch = get_group_id(2); + const int src_offset_batch = src_offset + src_stride * batch; + const int dest_offset_batch = dest_offset + dest_stride * batch; + real alpha; SetToOne(alpha); + _CopyPadMatrix(src_one, src_two, src_ld, src_offset_batch, src, + dest_one, dest_two, dest_ld, dest_offset_batch, dest, + alpha, do_conjugate); +} + +// Strided-batched version of the above +__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1))) +void CopyMatrixStridedBatched(const int src_one, const int src_two, + const int src_ld, const int src_offset, + const int src_stride, __global const real* restrict src, + const int dest_one, const int dest_two, + const int dest_ld, const int dest_offset, + const int dest_stride, __global real* dest) { + const int batch = get_group_id(2); + const int src_offset_batch = src_offset + src_stride * batch; + const int dest_offset_batch = dest_offset + dest_stride * batch; + real alpha; SetToOne(alpha); + _CopyMatrix(src_one, src_two, src_ld, src_offset_batch, src, + dest_one, dest_two, dest_ld, dest_offset_batch, dest, + alpha, 0, 0, 0); +} + +#endif +// ================================================================================================= // End of the C++11 raw string literal )" diff --git a/src/kernels/level3/transpose_pad.opencl b/src/kernels/level3/transpose_pad.opencl index 67c2bf72..e55a8b7c 100644 --- a/src/kernels/level3/transpose_pad.opencl +++ b/src/kernels/level3/transpose_pad.opencl @@ -231,6 +231,47 @@ void TransposeMatrixBatched(const int src_one, const int src_two, #endif // ================================================================================================= +#if defined(ROUTINE_GEMMSTRIDEDBATCHED) + +// Strided-batched version of the above +__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1))) +void TransposePadMatrixStridedBatched(const int src_one, const int src_two, + const int src_ld, const int src_offset, + const int src_stride, __global const real* restrict src, + const int dest_one, const int dest_two, + const int dest_ld, const int dest_offset, + const int dest_stride, __global real* dest, + const int do_conjugate) { + const int batch = get_group_id(2); + const int src_offset_batch = src_offset + src_stride * batch; + const int dest_offset_batch = dest_offset + dest_stride * batch; + real alpha; SetToOne(alpha); + __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)]; + _TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src, + dest_one, dest_two, dest_ld, dest_offset_batch, dest, + alpha, do_conjugate); +} + +// Strided-batched version of the above +__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1))) +void TransposeMatrixStridedBatched(const int src_one, const int src_two, + const int src_ld, const int src_offset, + const int src_stride, __global const real* restrict src, + const int dest_one, const int dest_two, + const int dest_ld, const int dest_offset, + const int dest_stride, __global real* dest) { + const int batch = get_group_id(2); + const int src_offset_batch = src_offset + src_stride * batch; + const int dest_offset_batch = dest_offset + dest_stride * batch; + real alpha; SetToOne(alpha); + __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)]; + _TransposeMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src, + dest_one, dest_two, dest_ld, dest_offset_batch, dest, + alpha, 0, 0, 0); +} + +#endif +// ================================================================================================= // End of the C++11 raw string literal )" diff --git a/src/kernels/level3/xgemm_batched.opencl b/src/kernels/level3/xgemm_batched.opencl index 372f910b..b51e6298 100644 --- a/src/kernels/level3/xgemm_batched.opencl +++ b/src/kernels/level3/xgemm_batched.opencl @@ -17,8 +17,8 @@ R"( // ================================================================================================= +#if defined(ROUTINE_GEMMBATCHED) -// Main entry point of the kernel. This is the regular full version. __kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1))) void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK, const __constant real_arg* arg_alphas, @@ -58,6 +58,49 @@ void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK, #endif } +#endif +// ================================================================================================= +#if defined(ROUTINE_GEMMSTRIDEDBATCHED) + +__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1))) +void XgemmStridedBatched(const int kSizeM, const int kSizeN, const int kSizeK, + const real_arg arg_alpha, const real_arg arg_beta, + const __global realM* restrict agm, const int a_one, const int a_two, + const __global realN* restrict bgm, const int b_one, const int b_two, + __global realM* cgm, const int c_one, const int c_two) { + const int batch = get_group_id(2); + const real alpha = GetRealArg(arg_alpha); + const real beta = GetRealArg(arg_beta); + + // Sets the offsets + const int a_offset = batch * a_one * a_two; + const int b_offset = batch * b_one * b_two; + const int c_offset = batch * c_one * c_two; + const __global realM* restrict agm_ = &agm[a_offset / VWM]; + const __global realN* restrict bgm_ = &bgm[b_offset / VWN]; + __global realM* restrict cgm_ = &cgm[c_offset / VWM]; + + // Allocates workgroup-private memory (local memory) + #if SA == 1 + __local realM alm[KWG * MWG/VWM]; + #endif + #if SB == 1 + __local realN blm[KWG * NWG/VWN]; + #endif + + // Computes the matrix-multiplication and stores the result in global memory + #if SA == 1 && SB == 1 + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm, blm); + #elif SA == 1 + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm); + #elif SB == 1 + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, blm); + #else + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta); + #endif +} + +#endif // ================================================================================================= // End of the C++11 raw string literal diff --git a/src/routines/common.hpp b/src/routines/common.hpp index 06d001d9..6cbe1e1b 100644 --- a/src/routines/common.hpp +++ b/src/routines/common.hpp @@ -239,6 +239,72 @@ void PadCopyTransposeMatrixBatched(Queue &queue, const Device &device, } } +// Batched version of the above +template <typename T> +void PadCopyTransposeMatrixStridedBatched(Queue &queue, const Device &device, + const Databases &db, + EventPointer event, const std::vector<Event> &waitForEvents, + const size_t src_one, const size_t src_two, + const size_t src_ld, const size_t src_offset, + const size_t src_stride, const Buffer<T> &src, + const size_t dest_one, const size_t dest_two, + const size_t dest_ld, const size_t dest_offset, + const size_t dest_stride, const Buffer<T> &dest, + const Program &program, const bool do_pad, + const bool do_transpose, const bool do_conjugate, + const size_t batch_count) { + + // Determines the right kernel + auto kernel_name = std::string{}; + if (do_transpose) { + kernel_name = (do_pad) ? "TransposePadMatrixStridedBatched" : "TransposeMatrixStridedBatched"; + } + else { + kernel_name = (do_pad) ? "CopyPadMatrixStridedBatched" : "CopyMatrixStridedBatched"; + } + + // Retrieves the kernel from the compiled binary + auto kernel = Kernel(program, kernel_name); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast<int>(src_one)); + kernel.SetArgument(1, static_cast<int>(src_two)); + kernel.SetArgument(2, static_cast<int>(src_ld)); + kernel.SetArgument(3, static_cast<int>(src_offset)); + kernel.SetArgument(4, static_cast<int>(src_stride)); + kernel.SetArgument(5, src()); + kernel.SetArgument(6, static_cast<int>(dest_one)); + kernel.SetArgument(7, static_cast<int>(dest_two)); + kernel.SetArgument(8, static_cast<int>(dest_ld)); + kernel.SetArgument(9, static_cast<int>(dest_offset)); + kernel.SetArgument(10, static_cast<int>(dest_stride)); + kernel.SetArgument(11, dest()); + if (do_pad) { + kernel.SetArgument(12, static_cast<int>(do_conjugate)); + } + + // Launches the kernel and returns the error code. Uses global and local thread sizes based on + // parameters in the database. + if (do_transpose) { + const auto global = std::vector<size_t>{ + Ceil(CeilDiv(dest_one, db["PADTRA_WPT"]), db["PADTRA_TILE"]), + Ceil(CeilDiv(dest_two, db["PADTRA_WPT"]), db["PADTRA_TILE"]), + batch_count + }; + const auto local = std::vector<size_t>{db["PADTRA_TILE"], db["PADTRA_TILE"], 1}; + RunKernel(kernel, queue, device, global, local, event, waitForEvents); + } + else { + const auto global = std::vector<size_t>{ + Ceil(CeilDiv(dest_one, db["PAD_WPTX"]), db["PAD_DIMX"]), + Ceil(CeilDiv(dest_two, db["PAD_WPTY"]), db["PAD_DIMY"]), + batch_count + }; + const auto local = std::vector<size_t>{db["PAD_DIMX"], db["PAD_DIMY"], 1}; + RunKernel(kernel, queue, device, global, local, event, waitForEvents); + } +} + // ================================================================================================= } // namespace clblast diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp index ddf7d878..affbceee 100644 --- a/src/routines/levelx/xgemmstridedbatched.cpp +++ b/src/routines/levelx/xgemmstridedbatched.cpp @@ -112,7 +112,7 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const size_t b_one, const size_t b_two, const size_t c_one, const size_t c_two, const size_t batch_count) { - /* TODO + // Calculates the ceiled versions of m, n, and k const auto m_ceiled = Ceil(Ceil(m, db_["MWG"]), db_["VWM"]); const auto n_ceiled = Ceil(Ceil(n, db_["NWG"]), db_["VWN"]); @@ -124,18 +124,10 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"], a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); - // Sets the "internal" offsets, i.e. the perfect offsets - auto a_offsets_i = 0;//std::vector<int>(batch_count); - auto b_offsets_i = 0;//std::vector<int>(batch_count); - auto c_offsets_i = 0;//std::vector<int>(batch_count); - // Determines whether or not temporary matrices are needed - auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offsets == a_offsets_i && - !a_do_transpose && !a_conjugate; - auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offsets == b_offsets_i && - !b_do_transpose && !b_conjugate; - auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offsets == c_offsets_i && - !c_do_transpose; + auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && !a_do_transpose && !a_conjugate; + auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && !b_do_transpose && !b_conjugate; + auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && !c_do_transpose; // Creates the temporary matrices const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, batch_count * a_one_i * a_two_i); @@ -150,43 +142,31 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, // to fill it up until it reaches a certain multiple of size (kernel parameter dependent). In // case nothing has to be done, these kernels can be skipped. if (!a_no_temp) { - auto a_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count); - auto a_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count); - a_offsets_device.Write(queue_, batch_count, a_offsets); - a_offsets_i_device.Write(queue_, batch_count, a_offsets_i); auto eventProcessA = Event(); - PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList, - a_one, a_two, a_ld, a_offsets_device, a_buffer, - a_one_i, a_two_i, a_one_i, a_offsets_i_device, a_temp, - program_, true, a_do_transpose, a_conjugate, batch_count); + PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList, + a_one, a_two, a_ld, a_offset, a_stride, a_buffer, + a_one_i, a_two_i, a_one_i, 0, a_one_i * a_two_i, a_temp, + program_, true, a_do_transpose, a_conjugate, batch_count); eventWaitList.push_back(eventProcessA); } // As above, but now for matrix B if (!b_no_temp) { - auto b_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count); - auto b_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count); - b_offsets_device.Write(queue_, batch_count, b_offsets); - b_offsets_i_device.Write(queue_, batch_count, b_offsets_i); auto eventProcessB = Event(); - PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList, - b_one, b_two, b_ld, b_offsets_device, b_buffer, - b_one_i, b_two_i, b_one_i, b_offsets_i_device, b_temp, - program_, true, b_do_transpose, b_conjugate, batch_count); + PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList, + b_one, b_two, b_ld, b_offset, b_stride, b_buffer, + b_one_i, b_two_i, b_one_i, 0, b_one_i * b_two_i, b_temp, + program_, true, b_do_transpose, b_conjugate, batch_count); eventWaitList.push_back(eventProcessB); } // As above, but now for matrix C - auto c_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count); - auto c_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count); if (!c_no_temp) { - c_offsets_device.Write(queue_, batch_count, c_offsets); - c_offsets_i_device.Write(queue_, batch_count, c_offsets_i); auto eventProcessC = Event(); - PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList, - c_one, c_two, c_ld, c_offsets_device, c_buffer, - c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp, - program_, true, c_do_transpose, false, batch_count); + PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList, + c_one, c_two, c_ld, c_offset, c_stride, c_buffer, + c_one_i, c_two_i, c_one_i, 0, c_one_i * c_two_i, c_temp, + program_, true, c_do_transpose, false, batch_count); eventWaitList.push_back(eventProcessC); } @@ -197,8 +177,8 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, kernel.SetArgument(0, static_cast<int>(m_ceiled)); kernel.SetArgument(1, static_cast<int>(n_ceiled)); kernel.SetArgument(2, static_cast<int>(k_ceiled)); - kernel.SetArgument(3, alpha); - kernel.SetArgument(4, beta); + kernel.SetArgument(3, GetRealArg(alpha)); + kernel.SetArgument(4, GetRealArg(beta)); kernel.SetArgument(5, a_temp()); kernel.SetArgument(6, static_cast<int>(a_one_i)); kernel.SetArgument(7, static_cast<int>(a_two_i)); @@ -225,12 +205,11 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, // Runs the post-processing kernel if needed if (!c_no_temp) { eventWaitList.push_back(eventKernel); - PadCopyTransposeMatrixBatched(queue_, device_, db_, event_, eventWaitList, - c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp, - c_one, c_two, c_ld, c_offsets_device, c_buffer, - program_, false, c_do_transpose, false, batch_count); + PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, event_, eventWaitList, + c_one_i, c_two_i, c_one_i, 0, c_one_i * c_two_i, c_temp, + c_one, c_two, c_ld, c_offset, c_stride, c_buffer, + program_, false, c_do_transpose, false, batch_count); } - */ } // ================================================================================================= |