diff options
Diffstat (limited to 'src/kernels/levelx/im2col.opencl')
-rw-r--r-- | src/kernels/levelx/im2col.opencl | 59 |
1 files changed, 49 insertions, 10 deletions
diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl index 301e076b..5db4cb5f 100644 --- a/src/kernels/levelx/im2col.opencl +++ b/src/kernels/levelx/im2col.opencl @@ -25,15 +25,16 @@ R"( // ================================================================================================= -__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) -void im2col(const int input_h, const int input_w, const int channels, - const int output_h, const int output_w, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const __global real* restrict im_buffer, const int im_offset, - __global real* col_buffer, const int col_offset) { +// Main body of the kernel +INLINE_FUNC void Xim2col(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const bool kernel_flip, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { // Thread IDs const int w_id = get_global_id(0); // image width, max 'output_w' @@ -58,7 +59,9 @@ void im2col(const int input_h, const int input_w, const int channels, } // Sets the output value - const int kernel_index = kw_id + kernel_w * kh_id; + const int kernel_index = (kernel_flip) + ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1 + : kw_id + kernel_w * kh_id; const int patch_index = w_id + output_w * h_id; const int output_index = patch_index + kernel_index * output_w * output_h + c_id * output_w * output_h * kernel_h * kernel_w; @@ -70,6 +73,42 @@ void im2col(const int input_h, const int input_w, const int channels, // ================================================================================================= +// Kernel flip version of the Xim2col kernel (for convolution) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xim2colKernelFlip(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { + const bool kernel_flip = true; + Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + kernel_flip, + im_buffer, im_offset, col_buffer, col_offset); +} + +// Normal version of the Xim2col kernel (for cross-correlation) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xim2colKernelNormal(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const __global real* restrict im_buffer, const int im_offset, + __global real* col_buffer, const int col_offset) { + const bool kernel_flip = false; + Xim2col(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + kernel_flip, + im_buffer, im_offset, col_buffer, col_offset); +} + +// ================================================================================================= + // End of the C++11 raw string literal )" |