From 803ca781f9be56f86a0806689f8886a2428d5b9f Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sat, 19 Aug 2017 18:25:13 +0200 Subject: First version of im2col kernel, unoptimized but working --- src/routines/levelx/xim2col.cpp | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) (limited to 'src/routines/levelx/xim2col.cpp') diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp index 150220d6..10c9c10c 100644 --- a/src/routines/levelx/xim2col.cpp +++ b/src/routines/levelx/xim2col.cpp @@ -23,7 +23,7 @@ namespace clblast { template Xim2col::Xim2col(Queue &queue, EventPointer event, const std::string &name): Routine(queue, event, name, {}, PrecisionValue(), {}, { -#include "../../kernels/level3/level3.opencl" +#include "../../kernels/levelx/im2col.opencl" }) { } @@ -40,6 +40,42 @@ void Xim2col::DoIm2col(const size_t channels, const size_t height, const size // Makes sure all dimensions are larger than zero if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); } + + // Sets the output height and width + const auto size_h = height + 2 * pad_h; + const auto padding_h = dilation_h * (kernel_h - 1) + 1; + const auto output_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1; + const auto size_w = width + 2 * pad_w; + const auto padding_w = dilation_w * (kernel_w - 1) + 1; + const auto output_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1; + + // Retrieves the Xcopy kernel from the compiled binary + auto kernel = Kernel(program_, "im2col"); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast(height)); + kernel.SetArgument(1, static_cast(width)); + kernel.SetArgument(2, static_cast(output_h)); + kernel.SetArgument(3, static_cast(output_w)); + kernel.SetArgument(4, static_cast(kernel_h)); + kernel.SetArgument(5, static_cast(kernel_w)); + kernel.SetArgument(6, static_cast(pad_h)); + kernel.SetArgument(7, static_cast(pad_w)); + kernel.SetArgument(8, static_cast(stride_h)); + kernel.SetArgument(9, static_cast(stride_w)); + kernel.SetArgument(10, static_cast(dilation_h)); + kernel.SetArgument(11, static_cast(dilation_w)); + kernel.SetArgument(12, im_buffer()); + kernel.SetArgument(13, static_cast(im_offset)); + kernel.SetArgument(14, col_buffer()); + kernel.SetArgument(15, static_cast(col_offset)); + + // Launches the kernel + const auto h_ceiled = Ceil(output_h, 16); + const auto w_ceiled = Ceil(output_w, 16); + auto global = std::vector{h_ceiled, w_ceiled, channels}; + auto local = std::vector{16, 16, 1}; + RunKernel(kernel, queue_, device_, global, local, event_); } // ================================================================================================= -- cgit v1.2.3