// ================================================================================================= // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- // width of 100 characters per line. // // Author(s): // Cedric Nugteren // // This file implements the Xcol2im class (see the header for information about the class). // // ================================================================================================= #include "routines/levelx/xcol2im.hpp" #include #include namespace clblast { // ================================================================================================= // Constructor: forwards to base class constructor template Xcol2im::Xcol2im(Queue &queue, EventPointer event, const std::string &name): Routine(queue, event, name, {"Copy"}, PrecisionValue(), {}, { #include "../../kernels/levelx/col2im.opencl" }) { } // ================================================================================================= // The main routine template void Xcol2im::DoCol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const Buffer &col_buffer, const size_t col_offset, const Buffer &im_buffer, const size_t im_offset) { // Makes sure all dimensions are larger than zero if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); } // Sets the output height and width const auto size_h = height + 2 * pad_h; const auto padding_h = dilation_h * (kernel_h - 1) + 1; const auto col_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1; const auto size_w = width + 2 * pad_w; const auto padding_w = dilation_w * (kernel_w - 1) + 1; const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1; int stride_bez_h = 0; int stride_bez_w = 0; int dilation_bez_h = 0; int dilation_bez_w = 0; int gcd_h = 0; int gcd_w = 0; EuclidGCD(static_cast(stride_h), static_cast(dilation_h), stride_bez_h, dilation_bez_h, gcd_h); EuclidGCD(static_cast(stride_w), static_cast(dilation_w), stride_bez_w, dilation_bez_w, gcd_w); // Retrieves the kernel from the compiled binary auto kernel = Kernel(program_, "col2im"); // Sets the kernel arguments kernel.SetArgument(0, static_cast(height)); kernel.SetArgument(1, static_cast(width)); kernel.SetArgument(2, static_cast(channels)); kernel.SetArgument(3, static_cast(col_h)); kernel.SetArgument(4, static_cast(col_w)); kernel.SetArgument(5, static_cast(kernel_h)); kernel.SetArgument(6, static_cast(kernel_w)); kernel.SetArgument(7, static_cast(pad_h)); kernel.SetArgument(8, static_cast(pad_w)); kernel.SetArgument(9, static_cast(stride_h)); kernel.SetArgument(10, static_cast(stride_w)); kernel.SetArgument(11, static_cast(dilation_h)); kernel.SetArgument(12, static_cast(dilation_w)); kernel.SetArgument(13, stride_bez_h); kernel.SetArgument(14, stride_bez_w); kernel.SetArgument(15, dilation_bez_h); kernel.SetArgument(16, dilation_bez_w); kernel.SetArgument(17, gcd_h); kernel.SetArgument(18, gcd_w); kernel.SetArgument(19, col_buffer()); kernel.SetArgument(20, static_cast(col_offset)); kernel.SetArgument(21, im_buffer()); kernel.SetArgument(22, static_cast(im_offset)); // Launches the kernel const auto w_ceiled = Ceil((width - 1) / gcd_w + 1, db_["COPY_DIMX"]); const auto h_ceiled = Ceil((height - 1) / gcd_h + 1, db_["COPY_DIMY"]); const auto global = std::vector{w_ceiled, h_ceiled * channels}; const auto local = std::vector{db_["COPY_DIMX"], db_["COPY_DIMY"]}; RunKernel(kernel, queue_, device_, global, local, event_); } // ================================================================================================= // Compiles the templated class template class Xcol2im; template class Xcol2im; template class Xcol2im; template class Xcol2im; template class Xcol2im; // ================================================================================================= } // namespace clblast