summaryrefslogtreecommitdiff
path: root/src/kernels/level3
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-08 21:07:01 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-01-08 21:07:01 +0100
commit99a4df88a6d808ea77c9116ce63621503c00b57a (patch)
treed8f8bc1b3884c0340df9f6d95b4837ed3dff8deb /src/kernels/level3
parent13f0f6fc6e612a5f77c6fd78b983f1b2bb8e36b6 (diff)
Implemented the in-direct version of the strided-batched GEMM kernel
Diffstat (limited to 'src/kernels/level3')
-rw-r--r--src/kernels/level3/copy_pad.opencl39
-rw-r--r--src/kernels/level3/transpose_pad.opencl41
-rw-r--r--src/kernels/level3/xgemm_batched.opencl45
3 files changed, 124 insertions, 1 deletions
diff --git a/src/kernels/level3/copy_pad.opencl b/src/kernels/level3/copy_pad.opencl
index 2e191514..3d389b74 100644
--- a/src/kernels/level3/copy_pad.opencl
+++ b/src/kernels/level3/copy_pad.opencl
@@ -174,6 +174,45 @@ void CopyMatrixBatched(const int src_one, const int src_two,
#endif
// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyPadMatrixStridedBatched(const int src_one, const int src_two,
+ const int src_ld, const int src_offset,
+ const int src_stride, __global const real* restrict src,
+ const int dest_one, const int dest_two,
+ const int dest_ld, const int dest_offset,
+ const int dest_stride, __global real* dest,
+ const int do_conjugate) {
+ const int batch = get_group_id(2);
+ const int src_offset_batch = src_offset + src_stride * batch;
+ const int dest_offset_batch = dest_offset + dest_stride * batch;
+ real alpha; SetToOne(alpha);
+ _CopyPadMatrix(src_one, src_two, src_ld, src_offset_batch, src,
+ dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+ alpha, do_conjugate);
+}
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyMatrixStridedBatched(const int src_one, const int src_two,
+ const int src_ld, const int src_offset,
+ const int src_stride, __global const real* restrict src,
+ const int dest_one, const int dest_two,
+ const int dest_ld, const int dest_offset,
+ const int dest_stride, __global real* dest) {
+ const int batch = get_group_id(2);
+ const int src_offset_batch = src_offset + src_stride * batch;
+ const int dest_offset_batch = dest_offset + dest_stride * batch;
+ real alpha; SetToOne(alpha);
+ _CopyMatrix(src_one, src_two, src_ld, src_offset_batch, src,
+ dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+ alpha, 0, 0, 0);
+}
+
+#endif
+// =================================================================================================
// End of the C++11 raw string literal
)"
diff --git a/src/kernels/level3/transpose_pad.opencl b/src/kernels/level3/transpose_pad.opencl
index 67c2bf72..e55a8b7c 100644
--- a/src/kernels/level3/transpose_pad.opencl
+++ b/src/kernels/level3/transpose_pad.opencl
@@ -231,6 +231,47 @@ void TransposeMatrixBatched(const int src_one, const int src_two,
#endif
// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposePadMatrixStridedBatched(const int src_one, const int src_two,
+ const int src_ld, const int src_offset,
+ const int src_stride, __global const real* restrict src,
+ const int dest_one, const int dest_two,
+ const int dest_ld, const int dest_offset,
+ const int dest_stride, __global real* dest,
+ const int do_conjugate) {
+ const int batch = get_group_id(2);
+ const int src_offset_batch = src_offset + src_stride * batch;
+ const int dest_offset_batch = dest_offset + dest_stride * batch;
+ real alpha; SetToOne(alpha);
+ __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+ _TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
+ dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+ alpha, do_conjugate);
+}
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposeMatrixStridedBatched(const int src_one, const int src_two,
+ const int src_ld, const int src_offset,
+ const int src_stride, __global const real* restrict src,
+ const int dest_one, const int dest_two,
+ const int dest_ld, const int dest_offset,
+ const int dest_stride, __global real* dest) {
+ const int batch = get_group_id(2);
+ const int src_offset_batch = src_offset + src_stride * batch;
+ const int dest_offset_batch = dest_offset + dest_stride * batch;
+ real alpha; SetToOne(alpha);
+ __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+ _TransposeMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
+ dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+ alpha, 0, 0, 0);
+}
+
+#endif
+// =================================================================================================
// End of the C++11 raw string literal
)"
diff --git a/src/kernels/level3/xgemm_batched.opencl b/src/kernels/level3/xgemm_batched.opencl
index 372f910b..b51e6298 100644
--- a/src/kernels/level3/xgemm_batched.opencl
+++ b/src/kernels/level3/xgemm_batched.opencl
@@ -17,8 +17,8 @@
R"(
// =================================================================================================
+#if defined(ROUTINE_GEMMBATCHED)
-// Main entry point of the kernel. This is the regular full version.
__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
const __constant real_arg* arg_alphas,
@@ -58,6 +58,49 @@ void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
}
+#endif
+// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+void XgemmStridedBatched(const int kSizeM, const int kSizeN, const int kSizeK,
+ const real_arg arg_alpha, const real_arg arg_beta,
+ const __global realM* restrict agm, const int a_one, const int a_two,
+ const __global realN* restrict bgm, const int b_one, const int b_two,
+ __global realM* cgm, const int c_one, const int c_two) {
+ const int batch = get_group_id(2);
+ const real alpha = GetRealArg(arg_alpha);
+ const real beta = GetRealArg(arg_beta);
+
+ // Sets the offsets
+ const int a_offset = batch * a_one * a_two;
+ const int b_offset = batch * b_one * b_two;
+ const int c_offset = batch * c_one * c_two;
+ const __global realM* restrict agm_ = &agm[a_offset / VWM];
+ const __global realN* restrict bgm_ = &bgm[b_offset / VWN];
+ __global realM* restrict cgm_ = &cgm[c_offset / VWM];
+
+ // Allocates workgroup-private memory (local memory)
+ #if SA == 1
+ __local realM alm[KWG * MWG/VWM];
+ #endif
+ #if SB == 1
+ __local realN blm[KWG * NWG/VWN];
+ #endif
+
+ // Computes the matrix-multiplication and stores the result in global memory
+ #if SA == 1 && SB == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm, blm);
+ #elif SA == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm);
+ #elif SB == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, blm);
+ #else
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta);
+ #endif
+}
+
+#endif
// =================================================================================================
// End of the C++11 raw string literal