summaryrefslogtreecommitdiff
path: root/src/routines/levelx/xcol2im.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines/levelx/xcol2im.cpp')
-rw-r--r--src/routines/levelx/xcol2im.cpp27
1 files changed, 21 insertions, 6 deletions
diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp
index 8339c02c..7a0c36b7 100644
--- a/src/routines/levelx/xcol2im.cpp
+++ b/src/routines/levelx/xcol2im.cpp
@@ -49,6 +49,15 @@ void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size
const auto padding_w = dilation_w * (kernel_w - 1) + 1;
const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
+ int stride_bez_h = 0;
+ int stride_bez_w = 0;
+ int dilation_bez_h = 0;
+ int dilation_bez_w = 0;
+ int gcd_h = 0;
+ int gcd_w = 0;
+ EuclidGCD(static_cast<int>(stride_h), static_cast<int>(dilation_h), stride_bez_h, dilation_bez_h, gcd_h);
+ EuclidGCD(static_cast<int>(stride_w), static_cast<int>(dilation_w), stride_bez_w, dilation_bez_w, gcd_w);
+
// Retrieves the kernel from the compiled binary
auto kernel = Kernel(program_, "col2im");
@@ -66,14 +75,20 @@ void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size
kernel.SetArgument(10, static_cast<int>(stride_w));
kernel.SetArgument(11, static_cast<int>(dilation_h));
kernel.SetArgument(12, static_cast<int>(dilation_w));
- kernel.SetArgument(13, col_buffer());
- kernel.SetArgument(14, static_cast<int>(col_offset));
- kernel.SetArgument(15, im_buffer());
- kernel.SetArgument(16, static_cast<int>(im_offset));
+ kernel.SetArgument(13, stride_bez_h);
+ kernel.SetArgument(14, stride_bez_w);
+ kernel.SetArgument(15, dilation_bez_h);
+ kernel.SetArgument(16, dilation_bez_w);
+ kernel.SetArgument(17, gcd_h);
+ kernel.SetArgument(18, gcd_w);
+ kernel.SetArgument(19, col_buffer());
+ kernel.SetArgument(20, static_cast<int>(col_offset));
+ kernel.SetArgument(21, im_buffer());
+ kernel.SetArgument(22, static_cast<int>(im_offset));
// Launches the kernel
- const auto w_ceiled = Ceil(col_w, db_["COPY_DIMX"]);
- const auto h_ceiled = Ceil(col_h, db_["COPY_DIMY"]);
+ const auto w_ceiled = Ceil((width - 1) / gcd_w + 1, db_["COPY_DIMX"]);
+ const auto h_ceiled = Ceil((height - 1) / gcd_h + 1, db_["COPY_DIMY"]);
const auto global = std::vector<size_t>{w_ceiled, h_ceiled * channels};
const auto local = std::vector<size_t>{db_["COPY_DIMX"], db_["COPY_DIMY"]};
RunKernel(kernel, queue_, device_, global, local, event_);