summary | refs | log | tree | commit | diff
path: root/src/kernels/levelx/im2col.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/levelx/im2col.opencl')
-rw-r--r-- src/kernels/levelx/im2col.opencl 59
1 files changed, 49 insertions, 10 deletions
diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl
index 301e076b..5db4cb5f 100644
--- a/src/kernels/levelx/im2col.opencl
+++ b/src/kernels/levelx/im2col.opencl
@@ -25,15 +25,16 @@ R"(
// =================================================================================================
-__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
-void im2col(const int input_h, const int input_w, const int channels,
- const int output_h, const int output_w,
- const int kernel_h, const int kernel_w,
- const int pad_h, const int pad_w,
- const int stride_h, const int stride_w,
- const int dilation_h, const int dilation_w,
- const __global real* restrict im_buffer, const int im_offset,
- __global real* col_buffer, const int col_offset) {
+// Main body of the kernel
+INLINE_FUNC void Xim2col(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const bool kernel_flip,
+ const __global real* restrict im_buffer, const int im_offset,
+ __global real* col_buffer, const int col_offset) {
// Thread IDs
const int w_id = get_global_id(0); // image width, max 'output_w'
@@ -58,7 +59,9 @@ void im2col(const int input_h, const int input_w, const int channels,
}
// Sets the output value
- const int kernel_index = kw_id + kernel_w * kh_id;
+ const int kernel_index = (kernel_flip)
+ ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1
+ : kw_id + kernel_w * kh_id;
const int patch_index = w_id + output_w * h_id;
const int output_index = patch_index + kernel_index * output_w * output_h +
c_id * output_w * output_h * kernel_h * kernel_w;
@@ -70,6 +73,42 @@ void im2col(const int input_h, const int input_w, const int channels,
// =================================================================================================
+// Kernel flip version of the Xim2col kernel (for convolution): forwards every argument to Xim2col
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xim2colKernelFlip(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const __global real* restrict im_buffer, const int im_offset, // image buffer is only read
+ __global real* col_buffer, const int col_offset) { // column buffer receives the output
+ const bool kernel_flip = true; // Xim2col counts kernel_index down from kernel_h * kernel_w - 1
+ Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ kernel_flip, // the only argument not taken from this kernel's own parameter list
+ im_buffer, im_offset, col_buffer, col_offset);
+}
+
+// Normal version of the Xim2col kernel (for cross-correlation): forwards every argument to Xim2col
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void Xim2colKernelNormal(const int input_h, const int input_w, const int channels,
+ const int output_h, const int output_w,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const __global real* restrict im_buffer, const int im_offset, // image buffer is only read
+ __global real* col_buffer, const int col_offset) { // column buffer receives the output
+ const bool kernel_flip = false; // Xim2col uses kernel_index = kw_id + kernel_w * kh_id
+ Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ kernel_flip, // the only argument not taken from this kernel's own parameter list
+ im_buffer, im_offset, col_buffer, col_offset);
+}
+
+// =================================================================================================
+
// End of the C++11 raw string literal
)"