summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-05-13 22:10:06 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2018-05-13 22:10:21 +0200
commit0cb95800424273e56740d3b23cef53e740eab9b5 (patch)
tree223405bb69ed1a59e8c78465336666d5e3761c66 /src
parentad8f1027abf6fecbf8f119172644e729a8a94d0c (diff)
Created a dedicated convgemm GEMM kernel as a copy of the batched direct gemm kernel
Diffstat (limited to 'src')
-rw-r--r--src/kernels/level3/xconvgemm.opencl228
-rw-r--r--src/kernels/level3/xgemm_direct_batched.opencl2
-rw-r--r--src/routines/levelx/xconvgemm.cpp69
3 files changed, 250 insertions, 49 deletions
diff --git a/src/kernels/level3/xconvgemm.opencl b/src/kernels/level3/xconvgemm.opencl
new file mode 100644
index 00000000..d3c53d7d
--- /dev/null
+++ b/src/kernels/level3/xconvgemm.opencl
@@ -0,0 +1,228 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file contains the an implementation of 3D convolution on a 4D image using GEMM kernels. It
+// uses parameters from the direct GEMM kernel.
+//
+// =================================================================================================
+
+// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
+// literal). Comment-out this line for syntax-highlighting when developing.
+R"(
+
+// =================================================================================================
+#if defined(ROUTINE_CONVGEMM)
+
+// ConvGEMM kernel
+__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
+void Xconvgemm(const int num_patches, const int num_kernels, const int patch_size,
+ const __global realMD* restrict colgm, const int col_offset, const int col_stride,
+ const __global realND* restrict kernelgm, const int kernel_offset,
+ __global real* resultgm, const int result_offset, const int result_stride) {
+
+ // Batch offsets
+ const int batch = get_group_id(2);
+ const int col_offset_batch = col_offset + col_stride * batch;
+ const int result_offset_batch = result_offset + result_stride * batch;
+
+ __local real alm[WGD * (WGD + PADA)];
+ __local real blm[WGD * (WGD + PADB)];
+
+ // Extra pointers to scalar versions of global memory
+ const __global real* restrict colgms = (const __global real* restrict) colgm;
+ const __global real* restrict kernelgms = (const __global real* restrict) kernelgm;
+
+ // Allocates workitem-private memory (registers)
+ #pragma promote_to_registers
+ real apd[MWID];
+ #pragma promote_to_registers
+ real bpd[NWID];
+ #pragma promote_to_registers
+ real cpd[NWID * MWID];
+
+ // Initializes the accumulation registers
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ SetToZero(cpd[_ni * MWID + _mi]);
+ }
+ }
+
+ // The faster version of GEMM is not allowed on the (incomplete) borders. Therefore, this section
+ // processes only the main parts: output blocks of WGD by WGD.
+ const int idm = get_local_id(0) * MWID + GetGroupID0() * WGD;
+ const int idn = get_local_id(1) * NWID + GetGroupID1() * WGD;
+ if ((idm < (num_patches/WGD)*WGD) && (idn < (num_kernels/WGD)*WGD)) {
+
+ // Loops over all complete workgroup tiles (K-dimension)
+ int kwg = 0;
+ for (; kwg < (patch_size/WGD) * WGD; kwg += WGD) {
+
+ // Loads data: off-chip --> local (matrix A and B)
+ if (num_patches % VWMD == 0 && col_offset_batch % VWMD == 0) {
+ GlobalToLocalDirectA(colgm, alm, num_patches, col_offset_batch, kwg, false, false);
+ }
+ else {
+ GlobalToLocalScalarA(colgms, alm, num_patches, col_offset_batch, kwg, false, false);
+ }
+ if (patch_size % VWND == 0 && kernel_offset % VWND == 0) {
+ GlobalToLocalDirectB(kernelgm, blm, patch_size, kernel_offset, kwg, true, false);
+ }
+ else {
+ GlobalToLocalScalarB(kernelgms, blm, patch_size, kernel_offset, kwg, true, false);
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ // Loops over all workitem tiles, unrolled by a factor KWID
+ for (int pwi = 0; pwi < WGD; pwi += KWID) {
+ #pragma unroll
+ for (int _pit = 0; _pit < KWID; _pit += 1) {
+ int kg = pwi + _pit;
+
+ // Loads data: local --> private (matrix A and B)
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ apd[_mi] = LocalToPrivateDirectA(alm, _mi, kg, false);
+ }
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ bpd[_ni] = LocalToPrivateDirectB(blm, _ni, kg, true);
+ }
+
+ // Performs the accumulation (Cpmd += Apmd * Bpmd)
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ MultiplyAdd(cpd[_ni * MWID + _mi], apd[_mi], bpd[_ni]);
+ }
+ }
+ }
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+
+ // Loop over the remaining part (incomplete tile in K-dimension)
+ for (; kwg < patch_size; ++kwg) {
+
+ // Loads data: off-chip --> private (matrix A and B)
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ apd[_mi] = GlobalToPrivateDirectA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false);
+ }
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ bpd[_ni] = GlobalToPrivateDirectB(kernelgms, _ni, patch_size, kernel_offset, idn, kwg, true, false);
+ }
+
+ // Performs the accumulation (Cpmd += Apmd * Bpmd)
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ MultiplyAdd(cpd[_ni * MWID + _mi], apd[_mi], bpd[_ni]);
+ }
+ }
+ }
+
+ // Stores a tile of results
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ StoreResultsDirect(resultgm, cpd[_ni * MWID + _mi], _mi, _ni, idm, idn,
+ ONE, ZERO, num_patches, result_offset_batch, false);
+ }
+ }
+ }
+
+ // Simple but slower version for the parts on the edge (incomplete tiles in M and N-dimensions)
+ else {
+
+ // Loops over all complete workgroup tiles (K-dimension)
+ int kwg = 0;
+ for (; kwg < (patch_size/WGD) * WGD; kwg+=WGD) {
+
+ // Loads data: off-chip --> local (matrix A and B)
+ GlobalToLocalCheckedA(colgms, alm, num_patches, col_offset_batch, kwg, false, false, num_patches, patch_size);
+ GlobalToLocalCheckedB(kernelgms, blm, patch_size, kernel_offset, kwg, true, false, num_kernels, patch_size);
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ // Loops over all workitem tiles, unrolled by a factor KWID
+ for (int pwi = 0; pwi < WGD; pwi += KWID) {
+ #pragma unroll
+ for (int _pit = 0; _pit < KWID; _pit += 1) {
+ int kg = pwi + _pit;
+
+ // Loads data: local --> private (matrix A and B)
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ apd[_mi] = LocalToPrivateDirectA(alm, _mi, kg, false);
+ }
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ bpd[_ni] = LocalToPrivateDirectB(blm, _ni, kg, true);
+ }
+
+ // Performs the accumulation (C += A * B)
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ MultiplyAdd(cpd[_ni * MWID + _mi], apd[_mi], bpd[_ni]);
+ }
+ }
+ }
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+
+ // Loop over the remaining part (incomplete tile in K-dimension)
+ for (; kwg < patch_size; ++kwg) {
+
+ // Loads data: off-chip --> private (matrix A and B)
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ apd[_mi] = GlobalToPrivateCheckedA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false, num_patches);
+ }
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ bpd[_ni] = GlobalToPrivateCheckedB(kernelgms, _ni, patch_size, kernel_offset, idn, kwg, true, false, num_kernels);
+ }
+
+ // Performs the accumulation (C += A * B)
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ MultiplyAdd(cpd[_ni * MWID + _mi], apd[_mi], bpd[_ni]);
+ }
+ }
+ }
+
+ // Stores a tile of results
+ #pragma unroll
+ for (int _ni = 0; _ni < NWID; _ni += 1) {
+ #pragma unroll
+ for (int _mi = 0; _mi < MWID; _mi += 1) {
+ StoreResultsChecked(resultgm, cpd[_ni * MWID + _mi], _mi, _ni, idm, idn, num_patches, num_kernels,
+ ONE, ZERO, num_patches, result_offset_batch, false);
+ }
+ }
+ }
+}
+
+#endif
+// =================================================================================================
+
+// End of the C++11 raw string literal
+)"
+
+// =================================================================================================
diff --git a/src/kernels/level3/xgemm_direct_batched.opencl b/src/kernels/level3/xgemm_direct_batched.opencl
index aebbb5ad..d15ed31e 100644
--- a/src/kernels/level3/xgemm_direct_batched.opencl
+++ b/src/kernels/level3/xgemm_direct_batched.opencl
@@ -105,7 +105,7 @@ void XgemmDirectBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// =================================================================================================
-#if defined(ROUTINE_GEMMSTRIDEDBATCHED) || defined(ROUTINE_CONVGEMM)
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
// Direct version of the strided-batched GEMM kernel with [A, B] = [non-transposed, non-transposed]
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index f36f72e5..8cb8093c 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -11,13 +11,12 @@
//
// =================================================================================================
-#include "routines/levelx/xconvgemm.hpp"
-#include "routines/levelx/xim2col.hpp"
-#include "routines/level3/xgemm.hpp"
-
#include <string>
#include <vector>
+#include "routines/levelx/xconvgemm.hpp"
+#include "routines/levelx/xim2col.hpp"
+
namespace clblast {
// =================================================================================================
@@ -32,7 +31,7 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam
#include "../../kernels/level3/xgemm_direct_part2.opencl"
#include "../../kernels/level3/xgemm_direct_part3.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
- #include "../../kernels/level3/xgemm_direct_batched.opencl"
+ #include "../../kernels/level3/xconvgemm.opencl"
}) {
}
@@ -93,61 +92,35 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
}
// Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
- const auto m = num_patches;
- const auto n = num_kernels;
- const auto k = patch_size;
- const auto col_ld = m;
- const auto kernel_ld = k;
- const auto result_ld = m;
const auto col_stride = patch_size * num_patches;
- const auto kernel_stride = size_t{0}; // applies the same kernel to all batches
const auto result_stride = num_kernels * output_h * output_w;
- // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
- bool col_do_transpose, kernel_do_transpose, result_do_transpose, col_conjugate, kernel_conjugate;
- size_t col_one, col_two, kernel_one, kernel_two, result_one, result_two;
- Xgemm<T>::ProcessArguments(Layout::kColMajor, Transpose::kNo, Transpose::kNo, m, n, k,
- col_one, col_two, kernel_one, kernel_two, result_one, result_two,
- col_do_transpose, kernel_do_transpose,
- result_do_transpose, col_conjugate, kernel_conjugate, 0);
-
// Tests the matrices for validity
+ TestMatrixB(patch_size, num_kernels, kernel_buffer, kernel_offset, patch_size);
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
- TestMatrixA(col_one, col_two, col_buffer, col_stride * batch, col_ld);
- TestMatrixB(kernel_one, kernel_two, kernel_buffer, kernel_offset + kernel_stride * batch, kernel_ld);
- TestMatrixC(result_one, result_two, result_buffer, result_offset + result_stride * batch, result_ld);
+ TestMatrixA(num_patches, patch_size, col_buffer, col_stride * batch, num_patches);
+ TestMatrixC(num_patches, num_kernels, result_buffer, result_offset + result_stride * batch, num_patches);
}
// Retrieves the proper XgemmDirect kernel from the compiled binary
- const auto name = (col_do_transpose) ? (kernel_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") :
- (kernel_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN");
- auto kernel = Kernel(program_, name);
+ auto kernel = Kernel(program_, "Xconvgemm");
// Sets the kernel arguments
- kernel.SetArgument(0, static_cast<int>(m));
- kernel.SetArgument(1, static_cast<int>(n));
- kernel.SetArgument(2, static_cast<int>(k));
- kernel.SetArgument(3, GetRealArg(ConstantOne<T>()));
- kernel.SetArgument(4, GetRealArg(ConstantZero<T>()));
- kernel.SetArgument(5, col_buffer());
- kernel.SetArgument(6, static_cast<int>(0));
- kernel.SetArgument(7, static_cast<int>(col_ld));
- kernel.SetArgument(8, static_cast<int>(col_stride));
- kernel.SetArgument(9, kernel_buffer());
- kernel.SetArgument(10, static_cast<int>(kernel_offset));
- kernel.SetArgument(11, static_cast<int>(kernel_ld));
- kernel.SetArgument(12, static_cast<int>(kernel_stride));
- kernel.SetArgument(13, result_buffer());
- kernel.SetArgument(14, static_cast<int>(result_offset));
- kernel.SetArgument(15, static_cast<int>(result_ld));
- kernel.SetArgument(16, static_cast<int>(result_stride));
- kernel.SetArgument(17, static_cast<int>(result_do_transpose));
- kernel.SetArgument(18, static_cast<int>(false));
- kernel.SetArgument(19, static_cast<int>(false));
+ kernel.SetArgument(0, static_cast<int>(num_patches));
+ kernel.SetArgument(1, static_cast<int>(num_kernels));
+ kernel.SetArgument(2, static_cast<int>(patch_size));
+ kernel.SetArgument(3, col_buffer());
+ kernel.SetArgument(4, static_cast<int>(0));
+ kernel.SetArgument(5, static_cast<int>(col_stride));
+ kernel.SetArgument(6, kernel_buffer());
+ kernel.SetArgument(7, static_cast<int>(kernel_offset));
+ kernel.SetArgument(8, result_buffer());
+ kernel.SetArgument(9, static_cast<int>(result_offset));
+ kernel.SetArgument(10, static_cast<int>(result_stride));
// Computes the global and local thread sizes
- const auto m_ceiled = Ceil(m, db_["WGD"]);
- const auto n_ceiled = Ceil(n, db_["WGD"]);
+ const auto m_ceiled = Ceil(num_patches, db_["WGD"]);
+ const auto n_ceiled = Ceil(num_kernels, db_["WGD"]);
const auto global = std::vector<size_t>{
(m_ceiled * db_["MDIMCD"]) / db_["WGD"],
(n_ceiled * db_["NDIMCD"]) / db_["WGD"],