From 0b3d04f70902e00f86c572a5e3c379f9335b216f Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Tue, 30 Oct 2018 14:54:55 +0900 Subject: Fix col2im implementation --- src/routines/levelx/xcol2im.cpp | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) (limited to 'src/routines/levelx/xcol2im.cpp') diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp index 8339c02c..7a0c36b7 100644 --- a/src/routines/levelx/xcol2im.cpp +++ b/src/routines/levelx/xcol2im.cpp @@ -49,6 +49,15 @@ void Xcol2im::DoCol2im(const size_t channels, const size_t height, const size const auto padding_w = dilation_w * (kernel_w - 1) + 1; const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1; + int stride_bez_h = 0; + int stride_bez_w = 0; + int dilation_bez_h = 0; + int dilation_bez_w = 0; + int gcd_h = 0; + int gcd_w = 0; + EuclidGCD(static_cast(stride_h), static_cast(dilation_h), stride_bez_h, dilation_bez_h, gcd_h); + EuclidGCD(static_cast(stride_w), static_cast(dilation_w), stride_bez_w, dilation_bez_w, gcd_w); + // Retrieves the kernel from the compiled binary auto kernel = Kernel(program_, "col2im"); @@ -66,14 +75,20 @@ void Xcol2im::DoCol2im(const size_t channels, const size_t height, const size kernel.SetArgument(10, static_cast(stride_w)); kernel.SetArgument(11, static_cast(dilation_h)); kernel.SetArgument(12, static_cast(dilation_w)); - kernel.SetArgument(13, col_buffer()); - kernel.SetArgument(14, static_cast(col_offset)); - kernel.SetArgument(15, im_buffer()); - kernel.SetArgument(16, static_cast(im_offset)); + kernel.SetArgument(13, stride_bez_h); + kernel.SetArgument(14, stride_bez_w); + kernel.SetArgument(15, dilation_bez_h); + kernel.SetArgument(16, dilation_bez_w); + kernel.SetArgument(17, gcd_h); + kernel.SetArgument(18, gcd_w); + kernel.SetArgument(19, col_buffer()); + kernel.SetArgument(20, static_cast(col_offset)); + kernel.SetArgument(21, im_buffer()); + kernel.SetArgument(22, static_cast(im_offset)); // Launches the kernel - const auto w_ceiled = Ceil(col_w, db_["COPY_DIMX"]); - const auto h_ceiled = Ceil(col_h, db_["COPY_DIMY"]); + const auto w_ceiled = Ceil((width - 1) / gcd_w + 1, db_["COPY_DIMX"]); + const auto h_ceiled = Ceil((height - 1) / gcd_h + 1, db_["COPY_DIMY"]); const auto global = std::vector{w_ceiled, h_ceiled * channels}; const auto local = std::vector{db_["COPY_DIMX"], db_["COPY_DIMY"]}; RunKernel(kernel, queue_, device_, global, local, event_); -- cgit v1.2.3