diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-08 21:07:01 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-08 21:07:01 +0100 |
commit | 99a4df88a6d808ea77c9116ce63621503c00b57a (patch) | |
tree | d8f8bc1b3884c0340df9f6d95b4837ed3dff8deb /src/kernels/level3/transpose_pad.opencl | |
parent | 13f0f6fc6e612a5f77c6fd78b983f1b2bb8e36b6 (diff) |
Implemented the in-direct version of the strided-batched GEMM kernel
Diffstat (limited to 'src/kernels/level3/transpose_pad.opencl')
-rw-r--r-- | src/kernels/level3/transpose_pad.opencl | 41 |
1 files changed, 41 insertions, 0 deletions
```diff
diff --git a/src/kernels/level3/transpose_pad.opencl b/src/kernels/level3/transpose_pad.opencl
index 67c2bf72..e55a8b7c 100644
--- a/src/kernels/level3/transpose_pad.opencl
+++ b/src/kernels/level3/transpose_pad.opencl
@@ -231,6 +231,47 @@ void TransposeMatrixBatched(const int src_one, const int src_two,
 #endif
 // =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposePadMatrixStridedBatched(const int src_one, const int src_two,
+                                      const int src_ld, const int src_offset,
+                                      const int src_stride, __global const real* restrict src,
+                                      const int dest_one, const int dest_two,
+                                      const int dest_ld, const int dest_offset,
+                                      const int dest_stride, __global real* dest,
+                                      const int do_conjugate) {
+  const int batch = get_group_id(2);
+  const int src_offset_batch = src_offset + src_stride * batch;
+  const int dest_offset_batch = dest_offset + dest_stride * batch;
+  real alpha; SetToOne(alpha);
+  __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+  _TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
+                      dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+                      alpha, do_conjugate);
+}
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposeMatrixStridedBatched(const int src_one, const int src_two,
+                                   const int src_ld, const int src_offset,
+                                   const int src_stride, __global const real* restrict src,
+                                   const int dest_one, const int dest_two,
+                                   const int dest_ld, const int dest_offset,
+                                   const int dest_stride, __global real* dest) {
+  const int batch = get_group_id(2);
+  const int src_offset_batch = src_offset + src_stride * batch;
+  const int dest_offset_batch = dest_offset + dest_stride * batch;
+  real alpha; SetToOne(alpha);
+  __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+  _TransposeMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
+                   dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+                   alpha, 0, 0, 0);
+}
+
+#endif
+// =================================================================================================
 // End of the C++11 raw string literal
 )"
```