summaryrefslogtreecommitdiff
path: root/src/kernels/levelx/col2im.opencl
blob: 44908ca17ea83a3eddb5451af81c865b70452b33 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// This file contains the col2im kernel, taken from:
// https://gist.github.com/vbkaisetsu/a98299df827f9a5245635f646c1d94be
// Credits go to https://github.com/vbkaisetsu
//
// =================================================================================================

// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
// literal). Comment-out this line for syntax-highlighting when developing.
R"(

// Work-group size parameters re-used from the 'copy' kernel
#ifndef COPY_DIMX
  #define COPY_DIMX 8      // Local workgroup size in the first dimension (w)
#endif
#ifndef COPY_DIMY
  #define COPY_DIMY 8      // Local workgroup size in the second dimension (h)
#endif

// =================================================================================================

inline int grid_ceil(const int x, const int step) {
  return x > 0 ? ((x - 1) / step + 1) * step : x / step * step;
}

__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
void col2im(const int input_h, const int input_w, const int channels,
            const int output_h, const int output_w,
            const int kernel_h, const int kernel_w,
            const int pad_h, const int pad_w,
            const int stride_h, const int stride_w,
            const int dilation_h, const int dilation_w,
            const int stride_bez_h, const int stride_bez_w,
            const int dilation_bez_h, const int dilation_bez_w,
            const int gcd_h, const int gcd_w,
            const __global real* restrict col_buffer, const int col_offset,
            __global real* im_buffer, const int im_offset) {

  const int input_h_scaled = (input_h - 1) / gcd_h + 1;

  // Thread IDs
  const int gcd_scale_w = get_global_id(0) + (pad_w - 1) / gcd_w + 1;
  const int gcd_scale_h = ((int) get_global_id(1)) % input_h_scaled + (pad_h - 1) / gcd_h + 1;
  const int c_id = ((int) get_global_id(1)) / input_h_scaled;

  const int w_index = gcd_scale_w * gcd_w - pad_w;
  const int h_index = gcd_scale_h * gcd_h - pad_h;
  const int th_step = stride_h * dilation_h / gcd_h;
  const int th_begin = grid_ceil(max(-stride_bez_h * gcd_scale_h * stride_h,
                                     (dilation_bez_h * gcd_scale_h - kernel_h + 1) * dilation_h),
                                 th_step);
  const int th_end = min((output_h - stride_bez_h * gcd_scale_h) * stride_h,
                         (dilation_bez_h * gcd_scale_h + 1) * dilation_h);
  const int tw_step = stride_w * dilation_w / gcd_w;
  const int tw_begin = grid_ceil(max(-stride_bez_w * gcd_scale_w * stride_w,
                                     (dilation_bez_w * gcd_scale_w - kernel_w + 1) * dilation_w),
                                 tw_step);
  const int tw_end = min((output_w - stride_bez_w * gcd_scale_w) * stride_w,
                         (dilation_bez_w * gcd_scale_w + 1) * dilation_w);
  if (w_index < input_w && c_id < channels) {
    real val;
    SetToZero(val);
    for (int th = th_begin; th < th_end; th += th_step) {
      for (int tw = tw_begin; tw < tw_end; tw += tw_step) {
        const int kh_id = -th / dilation_h + dilation_bez_h * gcd_scale_h;
        const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w;
        const int h_id = th / stride_h + stride_bez_h * gcd_scale_h;
        const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w;

        const int kernel_index = kw_id + kernel_w * kh_id;
        const int patch_index = w_id + output_w * h_id;
        const int output_index = patch_index + kernel_index * output_w * output_h +
                                 c_id * output_w * output_h * kernel_h * kernel_w;
        Add(val, val, col_buffer[output_index + col_offset]);
      }
    }

    // Sets the input value
    const int input_index = w_index + input_w * (h_index + input_h * c_id);
    im_buffer[input_index + im_offset] = val;
  }
}

// =================================================================================================

// End of the C++11 raw string literal
)"

// =================================================================================================