diff options
Diffstat (limited to 'src/kernels/levelx/col2im.opencl')
-rw-r--r-- | src/kernels/levelx/col2im.opencl | 73 |
1 files changed, 60 insertions, 13 deletions
diff --git a/src/kernels/levelx/col2im.opencl b/src/kernels/levelx/col2im.opencl index a37db24f..484a7a98 100644 --- a/src/kernels/levelx/col2im.opencl +++ b/src/kernels/levelx/col2im.opencl @@ -28,18 +28,20 @@ inline int grid_ceil(const int x, const int step) { return x > 0 ? ((x - 1) / step + 1) * step : x / step * step; } +// Main body of the kernel __kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) -void col2im(const int input_h, const int input_w, const int channels, - const int output_h, const int output_w, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int stride_bez_h, const int stride_bez_w, - const int dilation_bez_h, const int dilation_bez_w, - const int gcd_h, const int gcd_w, - const __global real* restrict col_buffer, const int col_offset, - __global real* im_buffer, const int im_offset) { +INLINE_FUNC void Xcol2im(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const bool kernel_flip, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { const int input_h_scaled = (input_h - 1) / gcd_h + 1; @@ -71,8 +73,9 @@ void col2im(const int input_h, const int input_w, const int channels, const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w; const int h_id = th / stride_h + stride_bez_h * gcd_scale_h; const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w; - - const int kernel_index = kw_id + kernel_w * kh_id; + const int kernel_index = (kernel_flip) + ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1 + : kw_id + kernel_w * kh_id; const int patch_index = w_id + output_w * h_id; const int output_index = patch_index + kernel_index * output_w * output_h + c_id * output_w * output_h * kernel_h * kernel_w; @@ -89,6 +92,50 @@ void col2im(const int input_h, const int input_w, const int channels, // ================================================================================================= +// Kernel flip version of the Xcol2im kernel (for convolution) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xcol2imKernelFlip(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { + const bool kernel_flip = true; + Xcol2im(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + stride_bez_h, stride_bez_w, dilation_bez_h, dilation_bez_w, gcd_h, gcd_w, + kernel_flip, + col_buffer, col_offset, im_buffer, im_offset); +} + +// Normal version of the Xcol2im kernel (for cross-correlation) +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void Xcol2imKernelNormal(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, + const __global real* restrict col_buffer, const int col_offset, + __global real* im_buffer, const int im_offset) { + const bool kernel_flip = false; + Xcol2im(input_h, input_w, channels, output_h, output_w, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + stride_bez_h, stride_bez_w, dilation_bez_h, dilation_bez_w, gcd_h, gcd_w, + kernel_flip, + col_buffer, col_offset, im_buffer, im_offset); +} + +// ================================================================================================= + // End of the C++11 raw string literal )" |