From 0b3d04f70902e00f86c572a5e3c379f9335b216f Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Tue, 30 Oct 2018 14:54:55 +0900 Subject: Fix col2im implementation --- src/kernels/levelx/col2im.opencl | 72 +++++++++++++++++++++++++--------------- src/routines/levelx/xcol2im.cpp | 27 +++++++++++---- src/utilities/utilities.cpp | 26 +++++++++++++++ src/utilities/utilities.hpp | 6 ++++ 4 files changed, 99 insertions(+), 32 deletions(-) (limited to 'src') diff --git a/src/kernels/levelx/col2im.opencl b/src/kernels/levelx/col2im.opencl index 76917795..44908ca1 100644 --- a/src/kernels/levelx/col2im.opencl +++ b/src/kernels/levelx/col2im.opencl @@ -24,6 +24,10 @@ R"( // ================================================================================================= +inline int grid_ceil(const int x, const int step) { + return x > 0 ? ((x - 1) / step + 1) * step : x / step * step; +} + __kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) void col2im(const int input_h, const int input_w, const int channels, const int output_h, const int output_w, @@ -31,38 +35,54 @@ void col2im(const int input_h, const int input_w, const int channels, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const int stride_bez_h, const int stride_bez_w, + const int dilation_bez_h, const int dilation_bez_w, + const int gcd_h, const int gcd_w, const __global real* restrict col_buffer, const int col_offset, - __global real *im_buffer, const int im_offset) { - const int x_x = get_global_id(0) + pad_w; - const int x_y = ((int) get_global_id(1)) % input_h + pad_h; - const int channel = ((int) get_global_id(1)) / input_h; - const int kernel_extent_w = (kernel_w - 1) * dilation_w + 1; - const int kernel_extent_h = (kernel_h - 1) * dilation_h + 1; - const int col_channel_shift = channel * kernel_w * kernel_h * output_h * output_w + col_offset; - const int x_channel_shift = channel * input_h * input_w + im_offset; - const int t_y_begin = (x_y < kernel_extent_h) ? 0 : (x_y - kernel_extent_h) / stride_h + 1; - const int t_y_end = min(x_y / stride_h + 1, output_h); - const int t_x_begin = (x_x < kernel_extent_w) ? 0 : (x_x - kernel_extent_w) / stride_w + 1; - const int t_x_end = min(x_x / stride_w + 1, output_w); + __global real* im_buffer, const int im_offset) { + + const int input_h_scaled = (input_h - 1) / gcd_h + 1; - if (x_x < input_w + pad_w && channel < channels) { + // Thread IDs + const int gcd_scale_w = get_global_id(0) + (pad_w - 1) / gcd_w + 1; + const int gcd_scale_h = ((int) get_global_id(1)) % input_h_scaled + (pad_h - 1) / gcd_h + 1; + const int c_id = ((int) get_global_id(1)) / input_h_scaled; + + const int w_index = gcd_scale_w * gcd_w - pad_w; + const int h_index = gcd_scale_h * gcd_h - pad_h; + const int th_step = stride_h * dilation_h / gcd_h; + const int th_begin = grid_ceil(max(-stride_bez_h * gcd_scale_h * stride_h, + (dilation_bez_h * gcd_scale_h - kernel_h + 1) * dilation_h), + th_step); + const int th_end = min((output_h - stride_bez_h * gcd_scale_h) * stride_h, + (dilation_bez_h * gcd_scale_h + 1) * dilation_h); + const int tw_step = stride_w * dilation_w / gcd_w; + const int tw_begin = grid_ceil(max(-stride_bez_w * gcd_scale_w * stride_w, + (dilation_bez_w * gcd_scale_w - kernel_w + 1) * dilation_w), + tw_step); + const int tw_end = min((output_w - stride_bez_w * gcd_scale_w) * stride_w, + (dilation_bez_w * gcd_scale_w + 1) * dilation_w); + if (w_index < input_w && c_id < channels) { real val; SetToZero(val); - for (int t_y = t_y_begin; t_y < t_y_end; ++t_y) { - for (int t_x = t_x_begin; t_x < t_x_end; ++t_x) { - int w_y = x_y - t_y * stride_h; - int w_x = x_x - t_x * stride_w; - if (w_y % dilation_h == 0 && w_x % dilation_w == 0) { - w_y /= dilation_h; - w_x /= dilation_w; - val += col_buffer[col_channel_shift - + (w_x + w_y * kernel_w) * output_h * output_w - + t_y * output_w - + t_x]; - } + for (int th = th_begin; th < th_end; th += th_step) { + for (int tw = tw_begin; tw < tw_end; tw += tw_step) { + const int kh_id = -th / dilation_h + dilation_bez_h * gcd_scale_h; + const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w; + const int h_id = th / stride_h + stride_bez_h * gcd_scale_h; + const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w; + + const int kernel_index = kw_id + kernel_w * kh_id; + const int patch_index = w_id + output_w * h_id; + const int output_index = patch_index + kernel_index * output_w * output_h + + c_id * output_w * output_h * kernel_h * kernel_w; + Add(val, val, col_buffer[output_index + col_offset]); } } - im_buffer[x_channel_shift + (x_y - pad_h) * input_w + x_x - pad_w] = val; + + // Sets the input value + const int input_index = w_index + input_w * (h_index + input_h * c_id); + im_buffer[input_index + im_offset] = val; } } diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp index 8339c02c..7a0c36b7 100644 --- a/src/routines/levelx/xcol2im.cpp +++ b/src/routines/levelx/xcol2im.cpp @@ -49,6 +49,15 @@ void Xcol2im::DoCol2im(const size_t channels, const size_t height, const size const auto padding_w = dilation_w * (kernel_w - 1) + 1; const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1; + int stride_bez_h = 0; + int stride_bez_w = 0; + int dilation_bez_h = 0; + int dilation_bez_w = 0; + int gcd_h = 0; + int gcd_w = 0; + EuclidGCD(static_cast(stride_h), static_cast(dilation_h), stride_bez_h, dilation_bez_h, gcd_h); + EuclidGCD(static_cast(stride_w), static_cast(dilation_w), stride_bez_w, dilation_bez_w, gcd_w); + // Retrieves the kernel from the compiled binary auto kernel = Kernel(program_, "col2im"); @@ -66,14 +75,20 @@ void Xcol2im::DoCol2im(const size_t channels, const size_t height, const size kernel.SetArgument(10, static_cast(stride_w)); kernel.SetArgument(11, static_cast(dilation_h)); kernel.SetArgument(12, static_cast(dilation_w)); - kernel.SetArgument(13, col_buffer()); - kernel.SetArgument(14, static_cast(col_offset)); - kernel.SetArgument(15, im_buffer()); - kernel.SetArgument(16, static_cast(im_offset)); + kernel.SetArgument(13, stride_bez_h); + kernel.SetArgument(14, stride_bez_w); + kernel.SetArgument(15, dilation_bez_h); + kernel.SetArgument(16, dilation_bez_w); + kernel.SetArgument(17, gcd_h); + kernel.SetArgument(18, gcd_w); + kernel.SetArgument(19, col_buffer()); + kernel.SetArgument(20, static_cast(col_offset)); + kernel.SetArgument(21, im_buffer()); + kernel.SetArgument(22, static_cast(im_offset)); // Launches the kernel - const auto w_ceiled = Ceil(col_w, db_["COPY_DIMX"]); - const auto h_ceiled = Ceil(col_h, db_["COPY_DIMY"]); + const auto w_ceiled = Ceil((width - 1) / gcd_w + 1, db_["COPY_DIMX"]); + const auto h_ceiled = Ceil((height - 1) / gcd_h + 1, db_["COPY_DIMY"]); const auto global = std::vector{w_ceiled, h_ceiled * channels}; const auto local = std::vector{db_["COPY_DIMX"], db_["COPY_DIMY"]}; RunKernel(kernel, queue_, device_, global, local, event_); diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index a8fdaa19..a6cd82e7 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -488,5 +488,31 @@ std::string GetDeviceName(const Device& device) { return device_name; } +// ================================================================================================= + +// Solve Bezout's identity +// a * p + b * q = r = GCD(a, b) +void EuclidGCD(int a, int b, int &p, int &q, int &r) { + p = 0; + q = 1; + int p_1 = 1; + int q_1 = 0; + for (;;) { + const int c = a % b; + if (c == 0) { + break; + } + const int p_2 = p_1; + const int q_2 = q_1; + p_1 = p; + q_1 = q; + p = p_2 - p_1 * (a / b); + q = q_2 - q_1 * (a / b); + a = b; + b = c; + } + r = b; +} + // ================================================================================================= } // namespace clblast diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp index 16a241af..fcc1c57f 100644 --- a/src/utilities/utilities.hpp +++ b/src/utilities/utilities.hpp @@ -371,6 +371,12 @@ std::string GetDeviceVendor(const Device& device); std::string GetDeviceArchitecture(const Device& device); std::string GetDeviceName(const Device& device); +// ================================================================================================= + +// Solve Bezout's identity +// a * p + b * q = r = GCD(a, b) +void EuclidGCD(int a, int b, int &p, int &q, int &r); + // ================================================================================================= } // namespace clblast -- cgit v1.2.3