summaryrefslogtreecommitdiff
path: root/src/kernels/levelx
diff options
context:
space:
mode:
authorKoichi Akabe <vbkaisetsu@gmail.com>2018-12-18 13:56:00 +0900
committerKoichi Akabe <vbkaisetsu@gmail.com>2018-12-18 13:56:00 +0900
commit301dc280dfe75ff3c8b219f64aea83a6bf2f0c8d (patch)
treee47fd45f74bc0dd326e6b120af341861b47f90aa /src/kernels/levelx
parent9819957768174dbb4929b970718a0d6018520979 (diff)
Fix xconvgemm kernel and enable ConvGemmMethod::kSingleKernel
Diffstat (limited to 'src/kernels/levelx')
-rw-r--r--src/kernels/levelx/xconvgemm_part1.opencl33
-rw-r--r--src/kernels/levelx/xconvgemm_part2.opencl94
2 files changed, 99 insertions, 28 deletions
diff --git a/src/kernels/levelx/xconvgemm_part1.opencl b/src/kernels/levelx/xconvgemm_part1.opencl
index abdb5324..25ccba51 100644
--- a/src/kernels/levelx/xconvgemm_part1.opencl
+++ b/src/kernels/levelx/xconvgemm_part1.opencl
@@ -11,7 +11,6 @@
// uses parameters from the direct GEMM kernel. This is the part with the loads from memory (1/2).
// This uses "CONVGEMM_WITH_IM2COL" as a switch to select between direct convgemm or first running
// the im2col kernel to create a 'col' temporary matrix.
-// TODO: Currently only works with 'CONVGEMM_WITH_IM2COL' set
//
// =================================================================================================
@@ -30,12 +29,17 @@ INLINE_FUNC real GlobalToPrivateCheckedImage(const __global real* restrict image
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w) {
+ const int dilation_h, const int dilation_w,
+ const bool kernel_flip) {
// Im2col indices
const int kernel_2d_index = kwg % (kernel_h * kernel_w);
- const int kw_id = kernel_2d_index % kernel_w;
- const int kh_id = kernel_2d_index / kernel_w;
+ const int kw_id = (kernel_flip)
+ ? kernel_w - kernel_2d_index % kernel_w - 1
+ : kernel_2d_index % kernel_w;
+ const int kh_id = (kernel_flip)
+ ? kernel_h - kernel_2d_index / kernel_w - 1
+ : kernel_2d_index / kernel_w;
const int c_id = kwg / (kernel_h * kernel_w);
const int h_index = -pad_h + kh_id * dilation_h + stride_h * h_id;
const int w_index = -pad_w + kw_id * dilation_w + stride_w * w_id;
@@ -55,14 +59,15 @@ INLINE_FUNC real GlobalToPrivateCheckedImage(const __global real* restrict image
// Loads global off-chip memory into local (shared) memory on-chip. This function is specific for
// loading the image input tensor. This includes a bounds check.
-INLINE_FUNC real GlobalToLocalCheckedImage(const __global realMD* restrict imagegm, LOCAL_PTR real* alm,
+INLINE_FUNC real GlobalToLocalCheckedImage(const __global real* restrict imagegm, LOCAL_PTR real* alm,
const int image_offset_batch,
- const int h_id, const int w_id, const int kwg,
+ const int output_w, const int kwg,
const int input_h, const int input_w, const int channels,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w) {
+ const int dilation_h, const int dilation_w,
+ const bool kernel_flip) {
#if MDIMCD == MDIMAD
const int la0 = get_local_id(0);
const int la1 = get_local_id(1);
@@ -82,10 +87,17 @@ INLINE_FUNC real GlobalToLocalCheckedImage(const __global realMD* restrict image
int idm = mg + GetGroupID0()*WGD;
int idk = kg + kwg;
+ const int w_id = idm % output_w;
+ const int h_id = idm / output_w;
+
// Im2col indices
const int kernel_2d_index = idk % (kernel_h * kernel_w);
- const int kw_id = kernel_2d_index % kernel_w;
- const int kh_id = kernel_2d_index / kernel_w;
+ const int kw_id = (kernel_flip)
+ ? kernel_w - kernel_2d_index % kernel_w - 1
+ : kernel_2d_index % kernel_w;
+ const int kh_id = (kernel_flip)
+ ? kernel_h - kernel_2d_index / kernel_w - 1
+ : kernel_2d_index / kernel_w;
const int c_id = idk / (kernel_h * kernel_w);
const int h_index = -pad_h + kh_id * dilation_h + stride_h * h_id;
const int w_index = -pad_w + kw_id * dilation_w + stride_w * w_id;
@@ -104,7 +116,8 @@ INLINE_FUNC real GlobalToLocalCheckedImage(const __global realMD* restrict image
}
}
-#endif
+#endif // defined(ROUTINE_CONVGEMM) && !defined(CONVGEMM_WITH_IM2COL)
+
// =================================================================================================
// End of the C++11 raw string literal
diff --git a/src/kernels/levelx/xconvgemm_part2.opencl b/src/kernels/levelx/xconvgemm_part2.opencl
index e0ac24a0..693cb120 100644
--- a/src/kernels/levelx/xconvgemm_part2.opencl
+++ b/src/kernels/levelx/xconvgemm_part2.opencl
@@ -11,7 +11,6 @@
// uses parameters from the direct GEMM kernel. This part contains the main kernel (2/2).
// This uses "CONVGEMM_WITH_IM2COL" as a switch to select between direct convgemm or first running
// the im2col kernel to create a 'col' temporary matrix.
-// TODO: Currently only works with 'CONVGEMM_WITH_IM2COL' set
//
// =================================================================================================
@@ -23,20 +22,25 @@ R"(
#if defined(ROUTINE_CONVGEMM)
// ConvGEMM kernel
+#if defined(CONVGEMM_WITH_IM2COL)
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
void Xconvgemm(const int num_patches, const int num_kernels, const int patch_size,
const __global realND* restrict kernelgm, const int kernel_offset,
__global real* resultgm, const int result_offset, const int result_stride,
-#if defined(CONVGEMM_WITH_IM2COL)
const __global realMD* restrict colgm, const int col_offset, const int col_stride)
#else
- const __global realMD* restrict imagegm, const int image_offset,
- const int input_h, const int input_w, const int channels,
- const int kernel_h, const int kernel_w,
- const int pad_h, const int pad_w,
- const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w,
- const int output_h, const int output_w)
+INLINE_FUNC void Xconvgemm(const int num_patches, const int num_kernels, const int patch_size,
+ const __global realND* restrict kernelgm, const int kernel_offset,
+ __global real* resultgm, const int result_offset, const int result_stride,
+ const __global realMD* restrict imagegm, const int image_offset,
+ const int input_h, const int input_w, const int channels,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int output_h, const int output_w,
+ LOCAL_PTR real* alm, LOCAL_PTR real* blm,
+ const bool kernel_flip)
#endif
{
@@ -49,12 +53,16 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
#endif
const int result_offset_batch = result_offset + result_stride * batch;
+#if defined(CONVGEMM_WITH_IM2COL)
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
+#endif
// Extra pointers to scalar versions of global memory
#if defined(CONVGEMM_WITH_IM2COL)
const __global real* restrict colgms = (const __global real* restrict) colgm;
+ #else
+ const __global real* restrict imagegms = (const __global real* restrict) imagegm;
#endif
const __global real* restrict kernelgms = (const __global real* restrict) kernelgm;
@@ -100,10 +108,10 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
GlobalToLocalScalarA(colgms, alm, num_patches, col_offset_batch, kwg, false, false);
}
#else
- GlobalToLocalCheckedImage(imagegm, alm, image_offset_batch, h_id, w_id, kwg,
+ GlobalToLocalCheckedImage(imagegms, alm, image_offset_batch, output_w, kwg,
input_h, input_w, channels, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
- dilation_h, dilation_w);
+ dilation_h, dilation_w, kernel_flip);
#endif
if (patch_size % VWND == 0 && kernel_offset % VWND == 0) {
GlobalToLocalDirectB(kernelgm, blm, patch_size, kernel_offset, kwg, true, false);
@@ -151,10 +159,12 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
#if defined(CONVGEMM_WITH_IM2COL)
apd[_mi] = GlobalToPrivateDirectA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false);
#else
- apd[_mi] = GlobalToPrivateCheckedImage(imagegm, image_offset_batch, h_id, w_id, kwg,
+ const int w_id = (idm + _mi) % output_w;
+ const int h_id = (idm + _mi) / output_w;
+ apd[_mi] = GlobalToPrivateCheckedImage(imagegms, image_offset_batch, h_id, w_id, kwg,
input_h, input_w, channels, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
- dilation_h, dilation_w);
+ dilation_h, dilation_w, kernel_flip);
#endif
}
#pragma unroll
@@ -193,10 +203,10 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
#if defined(CONVGEMM_WITH_IM2COL)
GlobalToLocalCheckedA(colgms, alm, num_patches, col_offset_batch, kwg, false, false, num_patches, patch_size);
#else
- GlobalToLocalCheckedImage(imagegm, alm, image_offset_batch, h_id, w_id, kwg,
+ GlobalToLocalCheckedImage(imagegms, alm, image_offset_batch, output_w, kwg,
input_h, input_w, channels, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
- dilation_h, dilation_w);
+ dilation_h, dilation_w, kernel_flip);
#endif
GlobalToLocalCheckedB(kernelgms, blm, patch_size, kernel_offset, kwg, true, false, num_kernels, patch_size);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -239,10 +249,12 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
#if defined(CONVGEMM_WITH_IM2COL)
apd[_mi] = GlobalToPrivateCheckedA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false, num_patches);
#else
- apd[_mi] = GlobalToPrivateCheckedImage(imagegm, image_offset_batch, h_id, w_id, kwg,
+ const int w_id = (idm + _mi) % output_w;
+ const int h_id = (idm + _mi) / output_w;
+ apd[_mi] = GlobalToPrivateCheckedImage(imagegms, image_offset_batch, h_id, w_id, kwg,
input_h, input_w, channels, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
- dilation_h, dilation_w);
+ dilation_h, dilation_w, kernel_flip);
#endif
}
#pragma unroll
@@ -272,7 +284,53 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
}
}
-#endif
+#if !defined(CONVGEMM_WITH_IM2COL)
+__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
+void XconvgemmFlip(const int num_patches, const int num_kernels, const int patch_size,
+ const __global realND* restrict kernelgm, const int kernel_offset,
+ __global real* resultgm, const int result_offset, const int result_stride,
+ const __global realMD* restrict imagegm, const int image_offset,
+ const int input_h, const int input_w, const int channels,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int output_h, const int output_w) {
+ const bool kernel_flip = true;
+ __local real alm[WGD * (WGD + PADA)];
+ __local real blm[WGD * (WGD + PADB)];
+ Xconvgemm(num_patches, num_kernels, patch_size,
+ kernelgm, kernel_offset, resultgm, result_offset, result_stride,
+ imagegm, image_offset, input_h, input_w, channels, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ output_h, output_w, alm, blm, kernel_flip);
+}
+
+__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
+void XconvgemmNormal(const int num_patches, const int num_kernels, const int patch_size,
+ const __global realND* restrict kernelgm, const int kernel_offset,
+ __global real* resultgm, const int result_offset, const int result_stride,
+ const __global realMD* restrict imagegm, const int image_offset,
+ const int input_h, const int input_w, const int channels,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int output_h, const int output_w) {
+ const bool kernel_flip = false;
+ __local real alm[WGD * (WGD + PADA)];
+ __local real blm[WGD * (WGD + PADB)];
+ Xconvgemm(num_patches, num_kernels, patch_size,
+ kernelgm, kernel_offset, resultgm, result_offset, result_stride,
+ imagegm, image_offset, input_h, input_w, channels, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ output_h, output_w, alm, blm, kernel_flip);
+}
+
+#endif // !defined(CONVGEMM_WITH_IM2COL)
+
+#endif // defined(ROUTINE_CONVGEMM)
+
// =================================================================================================
// End of the C++11 raw string literal