summary refs log tree commit diff
path: root/src/kernels/levelx/col2im.opencl
blob: ab0ffbfaba7e3229270dce74da8c9dc97b9645ed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// This file contains the col2im kernel, taken from:
// https://gist.github.com/vbkaisetsu/a98299df827f9a5245635f646c1d94be
// Credits go to https://github.com/vbkaisetsu
//
// =================================================================================================

// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
// literal). Comment-out this line for syntax-highlighting when developing.
R"(

// Work-group dimensions shared with the 'copy' kernel; the defaults below only apply when the
// values have not already been defined before this point
#if !defined(COPY_DIMX)
  #define COPY_DIMX 8      // Local workgroup size in the first dimension (w)
#endif
#if !defined(COPY_DIMY)
  #define COPY_DIMY 8      // Local workgroup size in the second dimension (h)
#endif

// =================================================================================================

// Rounds 'x' up (towards +infinity) to the nearest multiple of 'step', e.g. grid_ceil(5, 4) == 8,
// grid_ceil(-5, 4) == -4 and grid_ceil(0, 4) == 0. 'step' must be non-zero; the rounding direction
// described here holds for a positive 'step'. Uses the same INLINE_FUNC qualifier as Xcol2im
// rather than a hard-coded 'inline' keyword, so both device functions are qualified identically.
INLINE_FUNC int grid_ceil(const int x, const int step) {
  // Integer division truncates towards zero, which already rounds non-positive values up; positive
  // values need the explicit '(x - 1) / step + 1' ceiling form.
  return x > 0 ? ((x - 1) / step + 1) * step : x / step * step;
}

// Main body of the kernel. Each work-item handles one image element (c_id, h_index, w_index): it
// sums all col-buffer entries that map onto that element and adds the sum to the existing value in
// the im-buffer (im += sum). Global ID 0 walks the width, global ID 1 enumerates (channel, height)
// pairs. Only coordinates for which 'index + pad' is a multiple of 'gcd' are visited.
//
// The col-buffer is indexed as [channel][kernel_index][h_id][w_id] with extents
// [channels][kernel_h * kernel_w][output_h][output_w]; the im-buffer as [channel][h][w] with
// extents [channels][input_h][input_w]. 'col_offset' and 'im_offset' are added to every index.
// 'kernel_flip' selects the reversed kernel index (kernel_h * kernel_w - 1 - index).
//
// NOTE(review): presumably gcd_h/gcd_w are gcd(stride, dilation) and stride_bez_*/dilation_bez_*
// the matching Bezout coefficients computed by the host - confirm against the caller.
INLINE_FUNC void Xcol2im(const int input_h, const int input_w, const int channels,
                         const int output_h, const int output_w,
                         const int kernel_h, const int kernel_w,
                         const int pad_h, const int pad_w,
                         const int stride_h, const int stride_w,
                         const int dilation_h, const int dilation_w,
                         const int stride_bez_h, const int stride_bez_w,
                         const int dilation_bez_h, const int dilation_bez_w,
                         const int gcd_h, const int gcd_w,
                         const bool kernel_flip,
                         const __global real* restrict col_buffer, const int col_offset,
                         __global real* im_buffer, const int im_offset) {

  // Number of height positions visited per channel (one per 'gcd_h' rows of the image)
  const int input_h_scaled = (input_h - 1) / gcd_h + 1;

  // First multiple of 'gcd' that lands on a non-negative image coordinate: ceil(pad / gcd). The
  // '(pad - 1) / gcd + 1' form yields 1 instead of 0 for pad == 0 with gcd > 1 (integer division
  // truncates towards zero), which skips coordinate 0; this form is identical for pad >= 1.
  const int gcd_scale_w_begin = (pad_w + gcd_w - 1) / gcd_w;
  const int gcd_scale_h_begin = (pad_h + gcd_h - 1) / gcd_h;

  // Thread IDs
  const int gcd_scale_w = ((int) get_global_id(0)) + gcd_scale_w_begin;
  const int gcd_scale_h = ((int) get_global_id(1)) % input_h_scaled + gcd_scale_h_begin;
  const int c_id = ((int) get_global_id(1)) / input_h_scaled;

  // Image coordinates of this work-item
  const int w_index = gcd_scale_w * gcd_w - pad_w;
  const int h_index = gcd_scale_h * gcd_h - pad_h;

  // Loop bounds over the height: 'th' advances in steps of stride * dilation / gcd, starting at the
  // first multiple of that step for which both the kernel row and the output row are in range
  const int th_step = stride_h * dilation_h / gcd_h;
  const int th_begin = grid_ceil(max(-stride_bez_h * gcd_scale_h * stride_h,
                                     (dilation_bez_h * gcd_scale_h - kernel_h + 1) * dilation_h),
                                 th_step);
  const int th_end = min((output_h - stride_bez_h * gcd_scale_h) * stride_h,
                         (dilation_bez_h * gcd_scale_h + 1) * dilation_h);

  // Same for the width
  const int tw_step = stride_w * dilation_w / gcd_w;
  const int tw_begin = grid_ceil(max(-stride_bez_w * gcd_scale_w * stride_w,
                                     (dilation_bez_w * gcd_scale_w - kernel_w + 1) * dilation_w),
                                 tw_step);
  const int tw_end = min((output_w - stride_bez_w * gcd_scale_w) * stride_w,
                         (dilation_bez_w * gcd_scale_w + 1) * dilation_w);

  // The height needs its own bounds check: for gcd_h > 1 the last visited row can lie at or beyond
  // 'input_h', and writing it would alias a row of the next channel or run past the buffer
  if (h_index < input_h && w_index < input_w && c_id < channels) {
    real val;
    SetToZero(val);
    for (int th = th_begin; th < th_end; th += th_step) {
      for (int tw = tw_begin; tw < tw_end; tw += tw_step) {

        // Kernel and output-patch coordinates contributing to this image element
        const int kh_id = -th / dilation_h + dilation_bez_h * gcd_scale_h;
        const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w;
        const int h_id = th / stride_h + stride_bez_h * gcd_scale_h;
        const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w;
        const int kernel_index = (kernel_flip)
                               ? kernel_h * kernel_w - kw_id - kernel_w * kh_id - 1
                               : kw_id + kernel_w * kh_id;
        const int patch_index = w_id + output_w * h_id;
        const int output_index = patch_index + kernel_index * output_w * output_h +
                                 c_id * output_w * output_h * kernel_h * kernel_w;
        Add(val, val, col_buffer[output_index + col_offset]);
      }
    }

    // Accumulates the resulting value with the existing im-buffer (+= val)
    const int input_index = w_index + input_w * (h_index + input_h * c_id);
    real im_buffer_value = im_buffer[input_index + im_offset];
    Add(im_buffer[input_index + im_offset], im_buffer_value, val);
  }
}

// =================================================================================================

// Convolution entry point: forwards all arguments to Xcol2im with the kernel indices flipped. The
// required work-group size attribute is dropped when RELAX_WORKGROUP_SIZE is set to 1.
#if RELAX_WORKGROUP_SIZE == 1
  __kernel
#else
  __kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
#endif
void Xcol2imKernelFlip(const int input_h, const int input_w, const int channels,
                       const int output_h, const int output_w,
                       const int kernel_h, const int kernel_w,
                       const int pad_h, const int pad_w,
                       const int stride_h, const int stride_w,
                       const int dilation_h, const int dilation_w,
                       const int stride_bez_h, const int stride_bez_w,
                       const int dilation_bez_h, const int dilation_bez_w,
                       const int gcd_h, const int gcd_w,
                       const __global real* restrict col_buffer, const int col_offset,
                       __global real* im_buffer, const int im_offset) {
  Xcol2im(input_h, input_w, channels,
          output_h, output_w,
          kernel_h, kernel_w,
          pad_h, pad_w,
          stride_h, stride_w,
          dilation_h, dilation_w,
          stride_bez_h, stride_bez_w,
          dilation_bez_h, dilation_bez_w,
          gcd_h, gcd_w,
          true,  // kernel_flip
          col_buffer, col_offset,
          im_buffer, im_offset);
}

// Cross-correlation entry point: forwards all arguments to Xcol2im without flipping the kernel
// indices. The required work-group size attribute is dropped when RELAX_WORKGROUP_SIZE is set to 1.
#if RELAX_WORKGROUP_SIZE == 1
  __kernel
#else
  __kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
#endif
void Xcol2imKernelNormal(const int input_h, const int input_w, const int channels,
                         const int output_h, const int output_w,
                         const int kernel_h, const int kernel_w,
                         const int pad_h, const int pad_w,
                         const int stride_h, const int stride_w,
                         const int dilation_h, const int dilation_w,
                         const int stride_bez_h, const int stride_bez_w,
                         const int dilation_bez_h, const int dilation_bez_w,
                         const int gcd_h, const int gcd_w,
                         const __global real* restrict col_buffer, const int col_offset,
                         __global real* im_buffer, const int im_offset) {
  Xcol2im(input_h, input_w, channels,
          output_h, output_w,
          kernel_h, kernel_w,
          pad_h, pad_w,
          stride_h, stride_w,
          dilation_h, dilation_w,
          stride_bez_h, stride_bez_w,
          dilation_bez_h, dilation_bez_w,
          gcd_h, gcd_w,
          false,  // kernel_flip
          col_buffer, col_offset,
          im_buffer, im_offset);
}

// =================================================================================================

// End of the C++11 raw string literal
)"

// =================================================================================================