From 99a4df88a6d808ea77c9116ce63621503c00b57a Mon Sep 17 00:00:00 2001
From: Cedric Nugteren
Date: Mon, 8 Jan 2018 21:07:01 +0100
Subject: Implemented the in-direct version of the strided-batched GEMM kernel

---
 src/kernels/level3/copy_pad.opencl      | 39 ++++++++++++++++++++++++++++
 src/kernels/level3/transpose_pad.opencl | 41 ++++++++++++++++++++++++++++++
 src/kernels/level3/xgemm_batched.opencl | 45 ++++++++++++++++++++++++++++++++-
 3 files changed, 124 insertions(+), 1 deletion(-)

(limited to 'src/kernels/level3')

diff --git a/src/kernels/level3/copy_pad.opencl b/src/kernels/level3/copy_pad.opencl
index 2e191514..3d389b74 100644
--- a/src/kernels/level3/copy_pad.opencl
+++ b/src/kernels/level3/copy_pad.opencl
@@ -172,6 +172,45 @@ void CopyMatrixBatched(const int src_one, const int src_two,
               alpha, 0, 0, 0);
 }
 
+#endif
+// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyPadMatrixStridedBatched(const int src_one, const int src_two,
+                                 const int src_ld, const int src_offset,
+                                 const int src_stride, __global const real* restrict src,
+                                 const int dest_one, const int dest_two,
+                                 const int dest_ld, const int dest_offset,
+                                 const int dest_stride, __global real* dest,
+                                 const int do_conjugate) {
+  const int batch = get_group_id(2);
+  const int src_offset_batch = src_offset + src_stride * batch;
+  const int dest_offset_batch = dest_offset + dest_stride * batch;
+  real alpha; SetToOne(alpha);
+  _CopyPadMatrix(src_one, src_two, src_ld, src_offset_batch, src,
+                 dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+                 alpha, do_conjugate);
+}
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyMatrixStridedBatched(const int src_one, const int src_two,
+                              const int src_ld, const int src_offset,
+                              const int src_stride, __global const real* restrict src,
+                              const int dest_one, const int dest_two,
+                              const int dest_ld, const int dest_offset,
+                              const int dest_stride, __global real* dest) {
+  const int batch = get_group_id(2);
+  const int src_offset_batch = src_offset + src_stride * batch;
+  const int dest_offset_batch = dest_offset + dest_stride * batch;
+  real alpha; SetToOne(alpha);
+  _CopyMatrix(src_one, src_two, src_ld, src_offset_batch, src,
+              dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+              alpha, 0, 0, 0);
+}
+
 #endif
 
 // =================================================================================================
diff --git a/src/kernels/level3/transpose_pad.opencl b/src/kernels/level3/transpose_pad.opencl
index 67c2bf72..e55a8b7c 100644
--- a/src/kernels/level3/transpose_pad.opencl
+++ b/src/kernels/level3/transpose_pad.opencl
@@ -229,6 +229,47 @@ void TransposeMatrixBatched(const int src_one, const int src_two,
                    alpha, 0, 0, 0);
 }
 
+#endif
+// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposePadMatrixStridedBatched(const int src_one, const int src_two,
+                                      const int src_ld, const int src_offset,
+                                      const int src_stride, __global const real* restrict src,
+                                      const int dest_one, const int dest_two,
+                                      const int dest_ld, const int dest_offset,
+                                      const int dest_stride, __global real* dest,
+                                      const int do_conjugate) {
+  const int batch = get_group_id(2);
+  const int src_offset_batch = src_offset + src_stride * batch;
+  const int dest_offset_batch = dest_offset + dest_stride * batch;
+  real alpha; SetToOne(alpha);
+  __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+  _TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
+                      dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+                      alpha, do_conjugate);
+}
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposeMatrixStridedBatched(const int src_one, const int src_two,
+                                   const int src_ld, const int src_offset,
+                                   const int src_stride, __global const real* restrict src,
+                                   const int dest_one, const int dest_two,
+                                   const int dest_ld, const int dest_offset,
+                                   const int dest_stride, __global real* dest) {
+  const int batch = get_group_id(2);
+  const int src_offset_batch = src_offset + src_stride * batch;
+  const int dest_offset_batch = dest_offset + dest_stride * batch;
+  real alpha; SetToOne(alpha);
+  __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+  _TransposeMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
+                   dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+                   alpha, 0, 0, 0);
+}
+
 #endif
 
 // =================================================================================================
diff --git a/src/kernels/level3/xgemm_batched.opencl b/src/kernels/level3/xgemm_batched.opencl
index 372f910b..b51e6298 100644
--- a/src/kernels/level3/xgemm_batched.opencl
+++ b/src/kernels/level3/xgemm_batched.opencl
@@ -17,8 +17,8 @@
 R"(
 
 // =================================================================================================
+#if defined(ROUTINE_GEMMBATCHED)
 
-// Main entry point of the kernel. This is the regular full version.
 __kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
 void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
                   const __constant real_arg* arg_alphas,
@@ -58,6 +58,49 @@ void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
   #endif
 }
 
+#endif
+// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+void XgemmStridedBatched(const int kSizeM, const int kSizeN, const int kSizeK,
+                         const real_arg arg_alpha, const real_arg arg_beta,
+                         const __global realM* restrict agm, const int a_one, const int a_two,
+                         const __global realN* restrict bgm, const int b_one, const int b_two,
+                         __global realM* cgm, const int c_one, const int c_two) {
+  const int batch = get_group_id(2);
+  const real alpha = GetRealArg(arg_alpha);
+  const real beta = GetRealArg(arg_beta);
+
+  // Sets the offsets
+  const int a_offset = batch * a_one * a_two;
+  const int b_offset = batch * b_one * b_two;
+  const int c_offset = batch * c_one * c_two;
+  const __global realM* restrict agm_ = &agm[a_offset / VWM];
+  const __global realN* restrict bgm_ = &bgm[b_offset / VWN];
+  __global realM* restrict cgm_ = &cgm[c_offset / VWM];
+
+  // Allocates workgroup-private memory (local memory)
+  #if SA == 1
+    __local realM alm[KWG * MWG/VWM];
+  #endif
+  #if SB == 1
+    __local realN blm[KWG * NWG/VWN];
+  #endif
+
+  // Computes the matrix-multiplication and stores the result in global memory
+  #if SA == 1 && SB == 1
+    XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm, blm);
+  #elif SA == 1
+    XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm);
+  #elif SB == 1
+    XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, blm);
+  #else
+    XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta);
+  #endif
+}
+
+#endif
 // =================================================================================================
 
 // End of the C++11 raw string literal
--
cgit v1.2.3