summary | refs | log | tree | commit | diff
path: root/src/kernels/level3
diff options
context:
space:
mode:
author    Cedric Nugteren <web@cedricnugteren.nl>  2018-01-11 19:42:50 +0100
committer GitHub <noreply@github.com>              2018-01-11 19:42:50 +0100
commit    9b084d04093fdbfb22ee4790c6b3db5c55cd2719 (patch)
tree      d8f8bc1b3884c0340df9f6d95b4837ed3dff8deb /src/kernels/level3
parent    c988c2cdd166ebf6d5b5ec20f445de1a95a65b16 (diff)
parent    99a4df88a6d808ea77c9116ce63621503c00b57a (diff)
Merge pull request #239 from CNugteren/gemm_strided_batched
GemmStridedBatched
Diffstat (limited to 'src/kernels/level3')
-rw-r--r--  src/kernels/level3/copy_pad.opencl            |  39
-rw-r--r--  src/kernels/level3/transpose_pad.opencl       |  41
-rw-r--r--  src/kernels/level3/xgemm_batched.opencl       |  45
-rw-r--r--  src/kernels/level3/xgemm_direct_batched.opencl | 122
4 files changed, 226 insertions(+), 21 deletions(-)
diff --git a/src/kernels/level3/copy_pad.opencl b/src/kernels/level3/copy_pad.opencl
index 2e191514..3d389b74 100644
--- a/src/kernels/level3/copy_pad.opencl
+++ b/src/kernels/level3/copy_pad.opencl
@@ -174,6 +174,45 @@ void CopyMatrixBatched(const int src_one, const int src_two,
#endif
// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyPadMatrixStridedBatched(const int src_one, const int src_two,
+ const int src_ld, const int src_offset,
+ const int src_stride, __global const real* restrict src,
+ const int dest_one, const int dest_two,
+ const int dest_ld, const int dest_offset,
+ const int dest_stride, __global real* dest,
+ const int do_conjugate) {
+ const int batch = get_group_id(2);
+ const int src_offset_batch = src_offset + src_stride * batch;
+ const int dest_offset_batch = dest_offset + dest_stride * batch;
+ real alpha; SetToOne(alpha);
+ _CopyPadMatrix(src_one, src_two, src_ld, src_offset_batch, src,
+ dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+ alpha, do_conjugate);
+}
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyMatrixStridedBatched(const int src_one, const int src_two,
+ const int src_ld, const int src_offset,
+ const int src_stride, __global const real* restrict src,
+ const int dest_one, const int dest_two,
+ const int dest_ld, const int dest_offset,
+ const int dest_stride, __global real* dest) {
+ const int batch = get_group_id(2);
+ const int src_offset_batch = src_offset + src_stride * batch;
+ const int dest_offset_batch = dest_offset + dest_stride * batch;
+ real alpha; SetToOne(alpha);
+ _CopyMatrix(src_one, src_two, src_ld, src_offset_batch, src,
+ dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+ alpha, 0, 0, 0);
+}
+
+#endif
+// =================================================================================================
// End of the C++11 raw string literal
)"
diff --git a/src/kernels/level3/transpose_pad.opencl b/src/kernels/level3/transpose_pad.opencl
index 67c2bf72..e55a8b7c 100644
--- a/src/kernels/level3/transpose_pad.opencl
+++ b/src/kernels/level3/transpose_pad.opencl
@@ -231,6 +231,47 @@ void TransposeMatrixBatched(const int src_one, const int src_two,
#endif
// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposePadMatrixStridedBatched(const int src_one, const int src_two,
+ const int src_ld, const int src_offset,
+ const int src_stride, __global const real* restrict src,
+ const int dest_one, const int dest_two,
+ const int dest_ld, const int dest_offset,
+ const int dest_stride, __global real* dest,
+ const int do_conjugate) {
+ const int batch = get_group_id(2);
+ const int src_offset_batch = src_offset + src_stride * batch;
+ const int dest_offset_batch = dest_offset + dest_stride * batch;
+ real alpha; SetToOne(alpha);
+ __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+ _TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
+ dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+ alpha, do_conjugate);
+}
+
+// Strided-batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposeMatrixStridedBatched(const int src_one, const int src_two,
+ const int src_ld, const int src_offset,
+ const int src_stride, __global const real* restrict src,
+ const int dest_one, const int dest_two,
+ const int dest_ld, const int dest_offset,
+ const int dest_stride, __global real* dest) {
+ const int batch = get_group_id(2);
+ const int src_offset_batch = src_offset + src_stride * batch;
+ const int dest_offset_batch = dest_offset + dest_stride * batch;
+ real alpha; SetToOne(alpha);
+ __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+ _TransposeMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
+ dest_one, dest_two, dest_ld, dest_offset_batch, dest,
+ alpha, 0, 0, 0);
+}
+
+#endif
+// =================================================================================================
// End of the C++11 raw string literal
)"
diff --git a/src/kernels/level3/xgemm_batched.opencl b/src/kernels/level3/xgemm_batched.opencl
index 372f910b..b51e6298 100644
--- a/src/kernels/level3/xgemm_batched.opencl
+++ b/src/kernels/level3/xgemm_batched.opencl
@@ -17,8 +17,8 @@
R"(
// =================================================================================================
+#if defined(ROUTINE_GEMMBATCHED)
-// Main entry point of the kernel. This is the regular full version.
__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
const __constant real_arg* arg_alphas,
@@ -58,6 +58,49 @@ void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
}
+#endif
+// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
+void XgemmStridedBatched(const int kSizeM, const int kSizeN, const int kSizeK,
+ const real_arg arg_alpha, const real_arg arg_beta,
+ const __global realM* restrict agm, const int a_one, const int a_two,
+ const __global realN* restrict bgm, const int b_one, const int b_two,
+ __global realM* cgm, const int c_one, const int c_two) {
+ const int batch = get_group_id(2);
+ const real alpha = GetRealArg(arg_alpha);
+ const real beta = GetRealArg(arg_beta);
+
+ // Sets the offsets
+ const int a_offset = batch * a_one * a_two;
+ const int b_offset = batch * b_one * b_two;
+ const int c_offset = batch * c_one * c_two;
+ const __global realM* restrict agm_ = &agm[a_offset / VWM];
+ const __global realN* restrict bgm_ = &bgm[b_offset / VWN];
+ __global realM* restrict cgm_ = &cgm[c_offset / VWM];
+
+ // Allocates workgroup-private memory (local memory)
+ #if SA == 1
+ __local realM alm[KWG * MWG/VWM];
+ #endif
+ #if SB == 1
+ __local realN blm[KWG * NWG/VWN];
+ #endif
+
+ // Computes the matrix-multiplication and stores the result in global memory
+ #if SA == 1 && SB == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm, blm);
+ #elif SA == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm);
+ #elif SB == 1
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, blm);
+ #else
+ XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta);
+ #endif
+}
+
+#endif
// =================================================================================================
// End of the C++11 raw string literal
diff --git a/src/kernels/level3/xgemm_direct_batched.opencl b/src/kernels/level3/xgemm_direct_batched.opencl
index d946a056..d15ed31e 100644
--- a/src/kernels/level3/xgemm_direct_batched.opencl
+++ b/src/kernels/level3/xgemm_direct_batched.opencl
@@ -17,15 +17,16 @@
R"(
// =================================================================================================
+#if defined(ROUTINE_GEMMBATCHED)
// Direct version of the batched GEMM kernel with [A, B] = [non-transposed, non-transposed]
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
void XgemmDirectBatchedNN(const int kSizeM, const int kSizeN, const int kSizeK,
- const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
- const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
- const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
- __global real* cgm, const __constant int* c_offsets, const int c_ld,
- const int c_transpose, const int a_conjugate, const int b_conjugate) {
+ const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
+ const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
+ const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
+ __global real* cgm, const __constant int* c_offsets, const int c_ld,
+ const int c_transpose, const int a_conjugate, const int b_conjugate) {
const int batch = get_group_id(2);
const real_arg arg_alpha = arg_alphas[batch];
const real_arg arg_beta = arg_betas[batch];
@@ -42,11 +43,11 @@ void XgemmDirectBatchedNN(const int kSizeM, const int kSizeN, const int kSizeK,
// Direct version of the batched GEMM kernel with [A, B] = [non-transposed, transposed]
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
void XgemmDirectBatchedNT(const int kSizeM, const int kSizeN, const int kSizeK,
- const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
- const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
- const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
- __global real* cgm, const __constant int* c_offsets, const int c_ld,
- const int c_transpose, const int a_conjugate, const int b_conjugate) {
+ const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
+ const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
+ const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
+ __global real* cgm, const __constant int* c_offsets, const int c_ld,
+ const int c_transpose, const int a_conjugate, const int b_conjugate) {
const int batch = get_group_id(2);
const real_arg arg_alpha = arg_alphas[batch];
const real_arg arg_beta = arg_betas[batch];
@@ -63,11 +64,11 @@ void XgemmDirectBatchedNT(const int kSizeM, const int kSizeN, const int kSizeK,
// Direct version of the batched GEMM kernel with [A, B] = [transposed, non-transposed]
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
void XgemmDirectBatchedTN(const int kSizeM, const int kSizeN, const int kSizeK,
- const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
- const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
- const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
- __global real* cgm, const __constant int* c_offsets, const int c_ld,
- const int c_transpose, const int a_conjugate, const int b_conjugate) {
+ const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
+ const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
+ const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
+ __global real* cgm, const __constant int* c_offsets, const int c_ld,
+ const int c_transpose, const int a_conjugate, const int b_conjugate) {
const int batch = get_group_id(2);
const real_arg arg_alpha = arg_alphas[batch];
const real_arg arg_beta = arg_betas[batch];
@@ -84,11 +85,11 @@ void XgemmDirectBatchedTN(const int kSizeM, const int kSizeN, const int kSizeK,
// Direct version of the batched GEMM kernel with [A, B] = [transposed, transposed]
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
void XgemmDirectBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
- const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
- const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
- const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
- __global real* cgm, const __constant int* c_offsets, const int c_ld,
- const int c_transpose, const int a_conjugate, const int b_conjugate) {
+ const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
+ const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
+ const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
+ __global real* cgm, const __constant int* c_offsets, const int c_ld,
+ const int c_transpose, const int a_conjugate, const int b_conjugate) {
const int batch = get_group_id(2);
const real_arg arg_alpha = arg_alphas[batch];
const real_arg arg_beta = arg_betas[batch];
@@ -102,6 +103,87 @@ void XgemmDirectBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
alm, blm, 1, 1, c_transpose, a_conjugate, b_conjugate);
}
+#endif
+// =================================================================================================
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+
+// Direct version of the strided-batched GEMM kernel with [A, B] = [non-transposed, non-transposed]
+__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
+void XgemmDirectStridedBatchedNN(const int kSizeM, const int kSizeN, const int kSizeK,
+ const real_arg arg_alpha, const real_arg arg_beta,
+ const __global realMD* restrict agm, const int a_offset, const int a_ld, const int a_stride,
+ const __global realND* restrict bgm, const int b_offset, const int b_ld, const int b_stride,
+ __global real* cgm, const int c_offset, const int c_ld, const int c_stride,
+ const int c_transpose, const int a_conjugate, const int b_conjugate) {
+ const int batch = get_group_id(2);
+ const int a_offset_batch = a_offset + a_stride * batch;
+ const int b_offset_batch = b_offset + b_stride * batch;
+ const int c_offset_batch = c_offset + c_stride * batch;
+ __local real alm[WGD * (WGD + PADA)];
+ __local real blm[WGD * (WGD + PADB)];
+ XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
+ agm, a_offset_batch, a_ld, bgm, b_offset_batch, b_ld, cgm, c_offset_batch, c_ld,
+ alm, blm, 0, 0, c_transpose, a_conjugate, b_conjugate);
+}
+
+// Direct version of the strided-batched GEMM kernel with [A, B] = [non-transposed, transposed]
+__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
+void XgemmDirectStridedBatchedNT(const int kSizeM, const int kSizeN, const int kSizeK,
+ const real_arg arg_alpha, const real_arg arg_beta,
+ const __global realMD* restrict agm, const int a_offset, const int a_ld, const int a_stride,
+ const __global realND* restrict bgm, const int b_offset, const int b_ld, const int b_stride,
+ __global real* cgm, const int c_offset, const int c_ld, const int c_stride,
+ const int c_transpose, const int a_conjugate, const int b_conjugate) {
+ const int batch = get_group_id(2);
+ const int a_offset_batch = a_offset + a_stride * batch;
+ const int b_offset_batch = b_offset + b_stride * batch;
+ const int c_offset_batch = c_offset + c_stride * batch;
+ __local real alm[WGD * (WGD + PADA)];
+ __local real blm[WGD * (WGD + PADB)];
+ XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
+ agm, a_offset_batch, a_ld, bgm, b_offset_batch, b_ld, cgm, c_offset_batch, c_ld,
+ alm, blm, 0, 1, c_transpose, a_conjugate, b_conjugate);
+}
+
+// Direct version of the strided-batched GEMM kernel with [A, B] = [transposed, non-transposed]
+__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
+void XgemmDirectStridedBatchedTN(const int kSizeM, const int kSizeN, const int kSizeK,
+ const real_arg arg_alpha, const real_arg arg_beta,
+ const __global realMD* restrict agm, const int a_offset, const int a_ld, const int a_stride,
+ const __global realND* restrict bgm, const int b_offset, const int b_ld, const int b_stride,
+ __global real* cgm, const int c_offset, const int c_ld, const int c_stride,
+ const int c_transpose, const int a_conjugate, const int b_conjugate) {
+ const int batch = get_group_id(2);
+ const int a_offset_batch = a_offset + a_stride * batch;
+ const int b_offset_batch = b_offset + b_stride * batch;
+ const int c_offset_batch = c_offset + c_stride * batch;
+ __local real alm[WGD * (WGD + PADA)];
+ __local real blm[WGD * (WGD + PADB)];
+ XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
+ agm, a_offset_batch, a_ld, bgm, b_offset_batch, b_ld, cgm, c_offset_batch, c_ld,
+ alm, blm, 1, 0, c_transpose, a_conjugate, b_conjugate);
+}
+
+// Direct version of the strided-batched GEMM kernel with [A, B] = [transposed, transposed]
+__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
+void XgemmDirectStridedBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
+ const real_arg arg_alpha, const real_arg arg_beta,
+ const __global realMD* restrict agm, const int a_offset, const int a_ld, const int a_stride,
+ const __global realND* restrict bgm, const int b_offset, const int b_ld, const int b_stride,
+ __global real* cgm, const int c_offset, const int c_ld, const int c_stride,
+ const int c_transpose, const int a_conjugate, const int b_conjugate) {
+ const int batch = get_group_id(2);
+ const int a_offset_batch = a_offset + a_stride * batch;
+ const int b_offset_batch = b_offset + b_stride * batch;
+ const int c_offset_batch = c_offset + c_stride * batch;
+ __local real alm[WGD * (WGD + PADA)];
+ __local real blm[WGD * (WGD + PADB)];
+ XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
+ agm, a_offset_batch, a_ld, bgm, b_offset_batch, b_ld, cgm, c_offset_batch, c_ld,
+ alm, blm, 1, 1, c_transpose, a_conjugate, b_conjugate);
+}
+
+#endif
// =================================================================================================
// End of the C++11 raw string literal