summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/clblast.cpp36
-rw-r--r--src/clblast_c.cpp67
-rw-r--r--src/clblast_cuda.cpp38
-rw-r--r--src/clblast_netlib_c.cpp90
-rw-r--r--src/kernels/levelx/col2im.opencl95
-rw-r--r--src/routines/levelx/xcol2im.cpp107
-rw-r--r--src/routines/levelx/xcol2im.hpp45
-rw-r--r--src/routines/levelx/xim2col.cpp16
-rw-r--r--src/routines/levelx/xim2col.hpp1
-rw-r--r--src/routines/routines.hpp1
-rw-r--r--src/utilities/utilities.cpp26
-rw-r--r--src/utilities/utilities.hpp6
12 files changed, 520 insertions, 8 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 0cd2f843..e45f504a 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2252,6 +2252,42 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
cl_mem, const size_t,
cl_command_queue*, cl_event*);
+// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
+// Wraps the raw OpenCL handles in their C++ counterparts, runs Xcol2im<T>::DoCol2im and turns any
+// exception into a status code through DispatchException().
+template <typename T>
+StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+                  const cl_mem col_buffer, const size_t col_offset,
+                  cl_mem im_buffer, const size_t im_offset,
+                  cl_command_queue* queue, cl_event* event) {
+  try {
+    Queue queue_cpp(*queue);
+    Xcol2im<T> routine(queue_cpp, event);
+
+    // Non-owning C++ views on the user-supplied buffers
+    const Buffer<T> col_buffer_cpp(col_buffer);
+    const Buffer<T> im_buffer_cpp(im_buffer);
+
+    routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                     col_buffer_cpp, col_offset,
+                     im_buffer_cpp, im_offset);
+    return StatusCode::kSuccess;
+  } catch (...) {
+    return DispatchException();
+  }
+}
+template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const cl_mem, const size_t,
+ cl_mem, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const cl_mem, const size_t,
+ cl_mem, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const cl_mem, const size_t,
+ cl_mem, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const cl_mem, const size_t,
+ cl_mem, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const cl_mem, const size_t,
+ cl_mem, const size_t,
+ cl_command_queue*, cl_event*);
+
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 72adb888..645a69b1 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -3679,6 +3679,73 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
+// COL2IM
+// C API, single precision: forwards to clblast::Col2im<float> and converts the result (or any
+// escaping exception) into a CLBlastStatusCode.
+CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+                                 const cl_mem col_buffer, const size_t col_offset,
+                                 cl_mem im_buffer, const size_t im_offset,
+                                 cl_command_queue* queue, cl_event* event) {
+  try {
+    const auto status = clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                               col_buffer, col_offset,
+                                               im_buffer, im_offset,
+                                               queue, event);
+    return static_cast<CLBlastStatusCode>(status);
+  } catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+// C API, double precision: forwards to clblast::Col2im<double> and converts the result (or any
+// escaping exception) into a CLBlastStatusCode.
+CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+                                 const cl_mem col_buffer, const size_t col_offset,
+                                 cl_mem im_buffer, const size_t im_offset,
+                                 cl_command_queue* queue, cl_event* event) {
+  try {
+    const auto status = clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                                col_buffer, col_offset,
+                                                im_buffer, im_offset,
+                                                queue, event);
+    return static_cast<CLBlastStatusCode>(status);
+  } catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+// C API, single-precision complex: forwards to clblast::Col2im<float2> and converts the result (or
+// any escaping exception) into a CLBlastStatusCode.
+CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+                                 const cl_mem col_buffer, const size_t col_offset,
+                                 cl_mem im_buffer, const size_t im_offset,
+                                 cl_command_queue* queue, cl_event* event) {
+  try {
+    const auto status = clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                                col_buffer, col_offset,
+                                                im_buffer, im_offset,
+                                                queue, event);
+    return static_cast<CLBlastStatusCode>(status);
+  } catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+// C API, double-precision complex: forwards to clblast::Col2im<double2> and converts the result (or
+// any escaping exception) into a CLBlastStatusCode.
+CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+                                 const cl_mem col_buffer, const size_t col_offset,
+                                 cl_mem im_buffer, const size_t im_offset,
+                                 cl_command_queue* queue, cl_event* event) {
+  try {
+    const auto status = clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                                 col_buffer, col_offset,
+                                                 im_buffer, im_offset,
+                                                 queue, event);
+    return static_cast<CLBlastStatusCode>(status);
+  } catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+// C API, half precision: forwards to clblast::Col2im<half> and converts the result (or any
+// escaping exception) into a CLBlastStatusCode.
+CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+                                 const cl_mem col_buffer, const size_t col_offset,
+                                 cl_mem im_buffer, const size_t im_offset,
+                                 cl_command_queue* queue, cl_event* event) {
+  try {
+    const auto status = clblast::Col2im<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                              col_buffer, col_offset,
+                                              im_buffer, im_offset,
+                                              queue, event);
+    return static_cast<CLBlastStatusCode>(status);
+  } catch (...) {
+    return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC());
+  }
+}
+
// CONVGEMM
CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index f14806cb..03d995ba 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2350,6 +2350,44 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
+// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM
+// CUDA flavour: builds a queue from the given context and device, runs Xcol2im<T>::DoCol2im
+// without an event and turns any exception into a status code through DispatchException().
+template <typename T>
+StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+                  const CUdeviceptr col_buffer, const size_t col_offset,
+                  CUdeviceptr im_buffer, const size_t im_offset,
+                  const CUcontext context, const CUdevice device) {
+  try {
+    const Context context_cpp(context);
+    const Device device_cpp(device);
+    Queue queue_cpp(context_cpp, device_cpp);
+    Xcol2im<T> routine(queue_cpp, nullptr);
+
+    // Non-owning C++ views on the user-supplied device pointers
+    const Buffer<T> col_buffer_cpp(col_buffer);
+    const Buffer<T> im_buffer_cpp(im_buffer);
+
+    routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                     col_buffer_cpp, col_offset,
+                     im_buffer_cpp, im_offset);
+    return StatusCode::kSuccess;
+  } catch (...) {
+    return DispatchException();
+  }
+}
+template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t,
+ CUdeviceptr, const size_t,
+ const CUcontext, const CUdevice);
+template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t,
+ CUdeviceptr, const size_t,
+ const CUcontext, const CUdevice);
+template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t,
+ CUdeviceptr, const size_t,
+ const CUcontext, const CUdevice);
+template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t,
+ CUdeviceptr, const size_t,
+ const CUcontext, const CUdevice);
+template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t,
+ CUdeviceptr, const size_t,
+ const CUcontext, const CUdevice);
+
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
diff --git a/src/clblast_netlib_c.cpp b/src/clblast_netlib_c.cpp
index dbc2ba57..22570535 100644
--- a/src/clblast_netlib_c.cpp
+++ b/src/clblast_netlib_c.cpp
@@ -4967,4 +4967,94 @@ void cblas_zim2col(const int channels, const int height, const int width, const
col_buffer.Read(queue, col_size, reinterpret_cast<double2*>(col));
}
+// COL2IM
+// Netlib-style (host pointer) single-precision col2im. Uploads 'col' and the current contents of
+// 'im' to temporary device buffers, runs clblast::Col2im<float> (which accumulates into 'im') and
+// reads the image back. Throws std::runtime_error when CLBlast reports a failure.
+void cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+                   const float* col,
+                   float* im) {
+  OPTIONAL_STATIC auto device = get_device();
+  OPTIONAL_STATIC auto context = clblast::Context(device);
+  auto queue = clblast::Queue(context, device);
+
+  // Height and width of the 'col' result, using the same formula as Xcol2im::DoCol2im
+  const auto size_h = height + 2 * pad_h;
+  const auto padding_h = dilation_h * (kernel_h - 1) + 1;
+  const auto col_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1;
+  const auto size_w = width + 2 * pad_w;
+  const auto padding_w = dilation_w * (kernel_w - 1) + 1;
+  const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
+
+  // The kernel indexes 'col' as [channel][kernel_h][kernel_w][col_h][col_w], so the buffer has to
+  // hold that many elements rather than the number of image elements
+  const auto col_size = col_h * col_w * kernel_h * kernel_w * channels;
+  const auto im_size = height * width * channels;
+  auto col_buffer = clblast::Buffer<float>(context, col_size);
+  auto im_buffer = clblast::Buffer<float>(context, im_size);
+  col_buffer.Write(queue, col_size, reinterpret_cast<const float*>(col));
+  im_buffer.Write(queue, im_size, reinterpret_cast<float*>(im));
+  auto queue_cl = queue();
+  auto s = clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                  col_buffer(), 0,
+                                  im_buffer(), 0,
+                                  &queue_cl);
+  if (s != clblast::StatusCode::kSuccess) {
+    throw std::runtime_error("CLBlast returned with error code " + clblast::ToString(s));
+  }
+  im_buffer.Read(queue, im_size, reinterpret_cast<float*>(im));
+}
+// Netlib-style (host pointer) double-precision col2im. Uploads 'col' and the current contents of
+// 'im' to temporary device buffers, runs clblast::Col2im<double> (which accumulates into 'im') and
+// reads the image back. Throws std::runtime_error when CLBlast reports a failure.
+void cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+                   const double* col,
+                   double* im) {
+  OPTIONAL_STATIC auto device = get_device();
+  OPTIONAL_STATIC auto context = clblast::Context(device);
+  auto queue = clblast::Queue(context, device);
+
+  // Height and width of the 'col' result, using the same formula as Xcol2im::DoCol2im
+  const auto size_h = height + 2 * pad_h;
+  const auto padding_h = dilation_h * (kernel_h - 1) + 1;
+  const auto col_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1;
+  const auto size_w = width + 2 * pad_w;
+  const auto padding_w = dilation_w * (kernel_w - 1) + 1;
+  const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
+
+  // The kernel indexes 'col' as [channel][kernel_h][kernel_w][col_h][col_w], so the buffer has to
+  // hold that many elements rather than the number of image elements
+  const auto col_size = col_h * col_w * kernel_h * kernel_w * channels;
+  const auto im_size = height * width * channels;
+  auto col_buffer = clblast::Buffer<double>(context, col_size);
+  auto im_buffer = clblast::Buffer<double>(context, im_size);
+  col_buffer.Write(queue, col_size, reinterpret_cast<const double*>(col));
+  im_buffer.Write(queue, im_size, reinterpret_cast<double*>(im));
+  auto queue_cl = queue();
+  auto s = clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                   col_buffer(), 0,
+                                   im_buffer(), 0,
+                                   &queue_cl);
+  if (s != clblast::StatusCode::kSuccess) {
+    throw std::runtime_error("CLBlast returned with error code " + clblast::ToString(s));
+  }
+  im_buffer.Read(queue, im_size, reinterpret_cast<double*>(im));
+}
+// Netlib-style (host pointer) single-precision complex col2im. Uploads 'col' and the current
+// contents of 'im' to temporary device buffers, runs clblast::Col2im<float2> (which accumulates
+// into 'im') and reads the image back. Throws std::runtime_error when CLBlast reports a failure.
+void cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+                   const void* col,
+                   void* im) {
+  OPTIONAL_STATIC auto device = get_device();
+  OPTIONAL_STATIC auto context = clblast::Context(device);
+  auto queue = clblast::Queue(context, device);
+
+  // Height and width of the 'col' result, using the same formula as Xcol2im::DoCol2im
+  const auto size_h = height + 2 * pad_h;
+  const auto padding_h = dilation_h * (kernel_h - 1) + 1;
+  const auto col_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1;
+  const auto size_w = width + 2 * pad_w;
+  const auto padding_w = dilation_w * (kernel_w - 1) + 1;
+  const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
+
+  // The kernel indexes 'col' as [channel][kernel_h][kernel_w][col_h][col_w], so the buffer has to
+  // hold that many elements rather than the number of image elements
+  const auto col_size = col_h * col_w * kernel_h * kernel_w * channels;
+  const auto im_size = height * width * channels;
+  auto col_buffer = clblast::Buffer<float2>(context, col_size);
+  auto im_buffer = clblast::Buffer<float2>(context, im_size);
+  col_buffer.Write(queue, col_size, reinterpret_cast<const float2*>(col));
+  im_buffer.Write(queue, im_size, reinterpret_cast<float2*>(im));
+  auto queue_cl = queue();
+  auto s = clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                   col_buffer(), 0,
+                                   im_buffer(), 0,
+                                   &queue_cl);
+  if (s != clblast::StatusCode::kSuccess) {
+    throw std::runtime_error("CLBlast returned with error code " + clblast::ToString(s));
+  }
+  im_buffer.Read(queue, im_size, reinterpret_cast<float2*>(im));
+}
+// Netlib-style (host pointer) double-precision complex col2im. Uploads 'col' and the current
+// contents of 'im' to temporary device buffers, runs clblast::Col2im<double2> (which accumulates
+// into 'im') and reads the image back. Throws std::runtime_error when CLBlast reports a failure.
+void cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w,
+                   const void* col,
+                   void* im) {
+  OPTIONAL_STATIC auto device = get_device();
+  OPTIONAL_STATIC auto context = clblast::Context(device);
+  auto queue = clblast::Queue(context, device);
+
+  // Height and width of the 'col' result, using the same formula as Xcol2im::DoCol2im
+  const auto size_h = height + 2 * pad_h;
+  const auto padding_h = dilation_h * (kernel_h - 1) + 1;
+  const auto col_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1;
+  const auto size_w = width + 2 * pad_w;
+  const auto padding_w = dilation_w * (kernel_w - 1) + 1;
+  const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
+
+  // The kernel indexes 'col' as [channel][kernel_h][kernel_w][col_h][col_w], so the buffer has to
+  // hold that many elements rather than the number of image elements
+  const auto col_size = col_h * col_w * kernel_h * kernel_w * channels;
+  const auto im_size = height * width * channels;
+  auto col_buffer = clblast::Buffer<double2>(context, col_size);
+  auto im_buffer = clblast::Buffer<double2>(context, im_size);
+  col_buffer.Write(queue, col_size, reinterpret_cast<const double2*>(col));
+  im_buffer.Write(queue, im_size, reinterpret_cast<double2*>(im));
+  auto queue_cl = queue();
+  auto s = clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                    col_buffer(), 0,
+                                    im_buffer(), 0,
+                                    &queue_cl);
+  if (s != clblast::StatusCode::kSuccess) {
+    throw std::runtime_error("CLBlast returned with error code " + clblast::ToString(s));
+  }
+  im_buffer.Read(queue, im_size, reinterpret_cast<double2*>(im));
+}
+
// =================================================================================================
diff --git a/src/kernels/levelx/col2im.opencl b/src/kernels/levelx/col2im.opencl
new file mode 100644
index 00000000..a37db24f
--- /dev/null
+++ b/src/kernels/levelx/col2im.opencl
@@ -0,0 +1,95 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// This file contains the col2im kernel, taken from:
+// https://gist.github.com/vbkaisetsu/a98299df827f9a5245635f646c1d94be
+// Credits go to https://github.com/vbkaisetsu
+//
+// =================================================================================================
+
+// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
+// literal). Comment-out this line for syntax-highlighting when developing.
+R"(
+
+// Work-group size parameters re-used from the 'copy' kernel
+#ifndef COPY_DIMX
+ #define COPY_DIMX 8 // Local workgroup size in the first dimension (w)
+#endif
+#ifndef COPY_DIMY
+ #define COPY_DIMY 8 // Local workgroup size in the second dimension (h)
+#endif
+
+// =================================================================================================
+
+// Rounds 'x' up to a multiple of 'step'. For non-positive 'x' this relies on integer division
+// truncating towards zero. NOTE(review): presumably 'step' is always positive - confirm.
+inline int grid_ceil(const int x, const int step) {
+  if (x > 0) {
+    const int num_steps = (x - 1) / step + 1;
+    return num_steps * step;
+  }
+  return x / step * step;
+}
+
+// Accumulates the 'col' buffer back onto the image. Each work-item owns one image element
+// (w_index, h_index, c_id), sums every 'col' entry that maps onto it and adds that sum to the
+// existing im-buffer value. The '*_bez_*' and 'gcd_*' arguments are the Bezout coefficients and
+// the GCD of (stride, dilation) per dimension, as computed on the host by EuclidGCD. Only image
+// positions for which (index + pad) is a multiple of the GCD are visited.
+__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
+void col2im(const int input_h, const int input_w, const int channels,
+            const int output_h, const int output_w,
+            const int kernel_h, const int kernel_w,
+            const int pad_h, const int pad_w,
+            const int stride_h, const int stride_w,
+            const int dilation_h, const int dilation_w,
+            const int stride_bez_h, const int stride_bez_w,
+            const int dilation_bez_h, const int dilation_bez_w,
+            const int gcd_h, const int gcd_w,
+            const __global real* restrict col_buffer, const int col_offset,
+            __global real* im_buffer, const int im_offset) {
+
+  // Number of work-items along the second dimension that are assigned to a single channel
+  const int input_h_scaled = (input_h - 1) / gcd_h + 1;
+
+  // Thread IDs. The first visited multiple of the GCD is ceil(pad / gcd): the previous form
+  // '(pad - 1) / gcd + 1' yields 1 instead of 0 for pad == 0 and gcd > 1 because integer division
+  // truncates towards zero, which skipped row/column 0 and shifted all indices by one GCD step.
+  const int gcd_scale_w = ((int) get_global_id(0)) + (pad_w + gcd_w - 1) / gcd_w;
+  const int gcd_scale_h = ((int) get_global_id(1)) % input_h_scaled + (pad_h + gcd_h - 1) / gcd_h;
+  const int c_id = ((int) get_global_id(1)) / input_h_scaled;
+
+  const int w_index = gcd_scale_w * gcd_w - pad_w;
+  const int h_index = gcd_scale_h * gcd_h - pad_h;
+  const int th_step = stride_h * dilation_h / gcd_h;
+  const int th_begin = grid_ceil(max(-stride_bez_h * gcd_scale_h * stride_h,
+                                     (dilation_bez_h * gcd_scale_h - kernel_h + 1) * dilation_h),
+                                 th_step);
+  const int th_end = min((output_h - stride_bez_h * gcd_scale_h) * stride_h,
+                         (dilation_bez_h * gcd_scale_h + 1) * dilation_h);
+  const int tw_step = stride_w * dilation_w / gcd_w;
+  const int tw_begin = grid_ceil(max(-stride_bez_w * gcd_scale_w * stride_w,
+                                     (dilation_bez_w * gcd_scale_w - kernel_w + 1) * dilation_w),
+                                 tw_step);
+  const int tw_end = min((output_w - stride_bez_w * gcd_scale_w) * stride_w,
+                         (dilation_bez_w * gcd_scale_w + 1) * dilation_w);
+
+  // The height is bounds-checked as well: with a GCD larger than one the last scaled row can map
+  // to a row at or beyond 'input_h', which would touch another channel or run past the buffer
+  if (w_index < input_w && h_index < input_h && c_id < channels) {
+    real val;
+    SetToZero(val);
+    for (int th = th_begin; th < th_end; th += th_step) {
+      for (int tw = tw_begin; tw < tw_end; tw += tw_step) {
+        const int kh_id = -th / dilation_h + dilation_bez_h * gcd_scale_h;
+        const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w;
+        const int h_id = th / stride_h + stride_bez_h * gcd_scale_h;
+        const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w;
+
+        // 'col' layout: [channel][kernel_h][kernel_w][output_h][output_w]
+        const int kernel_index = kw_id + kernel_w * kh_id;
+        const int patch_index = w_id + output_w * h_id;
+        const int output_index = patch_index + kernel_index * output_w * output_h +
+                                 c_id * output_w * output_h * kernel_h * kernel_w;
+        Add(val, val, col_buffer[output_index + col_offset]);
+      }
+    }
+
+    // Accumulates the resulting value with the existing im-buffer (+= val)
+    const int input_index = w_index + input_w * (h_index + input_h * c_id);
+    real im_buffer_value = im_buffer[input_index + im_offset];
+    Add(im_buffer[input_index + im_offset], im_buffer_value, val);
+  }
+}
+
+// =================================================================================================
+
+// End of the C++11 raw string literal
+)"
+
+// =================================================================================================
diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp
new file mode 100644
index 00000000..7a0c36b7
--- /dev/null
+++ b/src/routines/levelx/xcol2im.cpp
@@ -0,0 +1,107 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Xcol2im class (see the header for information about the class).
+//
+// =================================================================================================
+
+#include "routines/levelx/xcol2im.hpp"
+
+#include <string>
+#include <vector>
+
+namespace clblast {
+// =================================================================================================
+
+// Constructor: forwards to base class constructor
+// Passes the "Copy" kernel name (this routine re-uses the copy kernel's tuning parameters) and
+// the col2im OpenCL source, which is pulled in as a C++11 raw string literal by the #include.
+template <typename T>
+Xcol2im<T>::Xcol2im(Queue &queue, EventPointer event, const std::string &name):
+ Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, {
+#include "../../kernels/levelx/col2im.opencl"
+ }) {
+}
+
+// =================================================================================================
+
+// The main routine
+// Validates the arguments, derives the 'col' dimensions and the Bezout coefficients / GCD of
+// (stride, dilation) per dimension, and launches the 'col2im' kernel, which accumulates the
+// contents of 'col_buffer' into 'im_buffer'. Throws BLASError(kInvalidDimension) when channels,
+// height or width is zero, or when a stride or dilation is zero.
+template <typename T>
+void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size_t width,
+                          const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
+                          const size_t pad_w, const size_t stride_h, const size_t stride_w,
+                          const size_t dilation_h, const size_t dilation_w,
+                          const Buffer<T> &col_buffer, const size_t col_offset,
+                          const Buffer<T> &im_buffer, const size_t im_offset) {
+
+  // Makes sure all dimensions are larger than zero
+  if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+  // Strides and dilations are used as divisors below, inside EuclidGCD and in the kernel: a zero
+  // value would result in a division by zero
+  if ((stride_h == 0) || (stride_w == 0) || (dilation_h == 0) || (dilation_w == 0)) {
+    throw BLASError(StatusCode::kInvalidDimension);
+  }
+
+  // Sets the height and width of the 'col' input
+  const auto size_h = height + 2 * pad_h;
+  const auto padding_h = dilation_h * (kernel_h - 1) + 1;
+  const auto col_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1;
+  const auto size_w = width + 2 * pad_w;
+  const auto padding_w = dilation_w * (kernel_w - 1) + 1;
+  const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
+
+  // Solves stride * stride_bez + dilation * dilation_bez = gcd for both dimensions
+  int stride_bez_h = 0;
+  int stride_bez_w = 0;
+  int dilation_bez_h = 0;
+  int dilation_bez_w = 0;
+  int gcd_h = 0;
+  int gcd_w = 0;
+  EuclidGCD(static_cast<int>(stride_h), static_cast<int>(dilation_h), stride_bez_h, dilation_bez_h, gcd_h);
+  EuclidGCD(static_cast<int>(stride_w), static_cast<int>(dilation_w), stride_bez_w, dilation_bez_w, gcd_w);
+
+  // Retrieves the kernel from the compiled binary
+  auto kernel = Kernel(program_, "col2im");
+
+  // Sets the kernel arguments
+  kernel.SetArgument(0, static_cast<int>(height));
+  kernel.SetArgument(1, static_cast<int>(width));
+  kernel.SetArgument(2, static_cast<int>(channels));
+  kernel.SetArgument(3, static_cast<int>(col_h));
+  kernel.SetArgument(4, static_cast<int>(col_w));
+  kernel.SetArgument(5, static_cast<int>(kernel_h));
+  kernel.SetArgument(6, static_cast<int>(kernel_w));
+  kernel.SetArgument(7, static_cast<int>(pad_h));
+  kernel.SetArgument(8, static_cast<int>(pad_w));
+  kernel.SetArgument(9, static_cast<int>(stride_h));
+  kernel.SetArgument(10, static_cast<int>(stride_w));
+  kernel.SetArgument(11, static_cast<int>(dilation_h));
+  kernel.SetArgument(12, static_cast<int>(dilation_w));
+  kernel.SetArgument(13, stride_bez_h);
+  kernel.SetArgument(14, stride_bez_w);
+  kernel.SetArgument(15, dilation_bez_h);
+  kernel.SetArgument(16, dilation_bez_w);
+  kernel.SetArgument(17, gcd_h);
+  kernel.SetArgument(18, gcd_w);
+  kernel.SetArgument(19, col_buffer());
+  kernel.SetArgument(20, static_cast<int>(col_offset));
+  kernel.SetArgument(21, im_buffer());
+  kernel.SetArgument(22, static_cast<int>(im_offset));
+
+  // Launches the kernel: one work-item per GCD-spaced image position, rounded up to the
+  // work-group size
+  const auto w_ceiled = Ceil((width - 1) / static_cast<size_t>(gcd_w) + 1, db_["COPY_DIMX"]);
+  const auto h_ceiled = Ceil((height - 1) / static_cast<size_t>(gcd_h) + 1, db_["COPY_DIMY"]);
+  const auto global = std::vector<size_t>{w_ceiled, h_ceiled * channels};
+  const auto local = std::vector<size_t>{db_["COPY_DIMX"], db_["COPY_DIMY"]};
+  RunKernel(kernel, queue_, device_, global, local, event_);
+}
+
+// =================================================================================================
+
+// Compiles the templated class
+template class Xcol2im<half>;
+template class Xcol2im<float>;
+template class Xcol2im<double>;
+template class Xcol2im<float2>;
+template class Xcol2im<double2>;
+
+// =================================================================================================
+} // namespace clblast
diff --git a/src/routines/levelx/xcol2im.hpp b/src/routines/levelx/xcol2im.hpp
new file mode 100644
index 00000000..86d68c45
--- /dev/null
+++ b/src/routines/levelx/xcol2im.hpp
@@ -0,0 +1,45 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Xcol2im routine. The precision is implemented using a template argument.
+// Uses the tuning parameters from the regular copy kernel.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_ROUTINES_XCOL2IM_H_
+#define CLBLAST_ROUTINES_XCOL2IM_H_
+
+#include "routine.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+// See comment at top of file for a description of the class
+template <typename T>
+class Xcol2im: public Routine {
+ public:
+
+ // Constructor
+ // 'name' defaults to "COL2IM" and is forwarded to the Routine base class
+ Xcol2im(Queue &queue, EventPointer event, const std::string &name = "COL2IM");
+
+ // Templated-precision implementation of the routine
+ // Takes the image dimensions (channels, height, width), the kernel size, padding, stride and
+ // dilation per dimension, and the 'col' (input) and 'im' (accumulated output) buffers with
+ // their offsets
+ void DoCol2im(const size_t channels, const size_t height, const size_t width,
+ const size_t kernel_h, const size_t kernel_w,
+ const size_t pad_h, const size_t pad_w,
+ const size_t stride_h, const size_t stride_w,
+ const size_t dilation_h, const size_t dilation_w,
+ const Buffer<T> &col_buffer, const size_t col_offset,
+ const Buffer<T> &im_buffer, const size_t im_offset);
+};
+
+// =================================================================================================
+} // namespace clblast
+
+// CLBLAST_ROUTINES_XCOL2IM_H_
+#endif
diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp
index dfbb4bb5..09dcc42c 100644
--- a/src/routines/levelx/xim2col.cpp
+++ b/src/routines/levelx/xim2col.cpp
@@ -41,23 +41,23 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size
// Makes sure all dimensions are larger than zero
if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
- // Sets the output height and width
+ // Sets the height and width of the 'col' result
const auto size_h = height + 2 * pad_h;
const auto padding_h = dilation_h * (kernel_h - 1) + 1;
- const auto output_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1;
+ const auto col_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1;
const auto size_w = width + 2 * pad_w;
const auto padding_w = dilation_w * (kernel_w - 1) + 1;
- const auto output_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
+ const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
- // Retrieves the Xcopy kernel from the compiled binary
+ // Retrieves the kernel from the compiled binary
auto kernel = Kernel(program_, "im2col");
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(height));
kernel.SetArgument(1, static_cast<int>(width));
kernel.SetArgument(2, static_cast<int>(channels));
- kernel.SetArgument(3, static_cast<int>(output_h));
- kernel.SetArgument(4, static_cast<int>(output_w));
+ kernel.SetArgument(3, static_cast<int>(col_h));
+ kernel.SetArgument(4, static_cast<int>(col_w));
kernel.SetArgument(5, static_cast<int>(kernel_h));
kernel.SetArgument(6, static_cast<int>(kernel_w));
kernel.SetArgument(7, static_cast<int>(pad_h));
@@ -72,8 +72,8 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size
kernel.SetArgument(16, static_cast<int>(col_offset));
// Launches the kernel
- const auto w_ceiled = Ceil(output_w, db_["COPY_DIMX"]);
- const auto h_ceiled = Ceil(output_h, db_["COPY_DIMY"]);
+ const auto w_ceiled = Ceil(col_w, db_["COPY_DIMX"]);
+ const auto h_ceiled = Ceil(col_h, db_["COPY_DIMY"]);
const auto global = std::vector<size_t>{w_ceiled, h_ceiled * channels};
const auto local = std::vector<size_t>{db_["COPY_DIMX"], db_["COPY_DIMY"]};
RunKernel(kernel, queue_, device_, global, local, event_);
diff --git a/src/routines/levelx/xim2col.hpp b/src/routines/levelx/xim2col.hpp
index 4448b54e..2c03b169 100644
--- a/src/routines/levelx/xim2col.hpp
+++ b/src/routines/levelx/xim2col.hpp
@@ -8,6 +8,7 @@
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the Xim2col routine. The precision is implemented using a template argument.
+// Uses the tuning parameters from the regular copy kernel.
//
// =================================================================================================
diff --git a/src/routines/routines.hpp b/src/routines/routines.hpp
index e080ed47..95475470 100644
--- a/src/routines/routines.hpp
+++ b/src/routines/routines.hpp
@@ -70,6 +70,7 @@
#include "routines/levelx/xhad.hpp"
#include "routines/levelx/xomatcopy.hpp"
#include "routines/levelx/xim2col.hpp"
+#include "routines/levelx/xcol2im.hpp"
#include "routines/levelx/xconvgemm.hpp"
#include "routines/levelx/xaxpybatched.hpp"
#include "routines/levelx/xgemmbatched.hpp"
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp
index a8fdaa19..a6cd82e7 100644
--- a/src/utilities/utilities.cpp
+++ b/src/utilities/utilities.cpp
@@ -489,4 +489,30 @@ std::string GetDeviceName(const Device& device) {
}
// =================================================================================================
+
+// Solve Bezout's identity
+// a * p + b * q = r = GCD(a, b)
+// Extended Euclidean algorithm: on return 'r' holds the GCD and 'p' and 'q' a pair of
+// coefficients satisfying the identity above for the original 'a' and 'b'.
+void EuclidGCD(int a, int b, int &p, int &q, int &r) {
+  // GCD(a, 0) = a = a * 1 + 0 * 0: handled up-front because the loop below evaluates a % b
+  if (b == 0) {
+    p = 1;
+    q = 0;
+    r = a;
+    return;
+  }
+  p = 0;
+  q = 1;
+  int p_1 = 1;
+  int q_1 = 0;
+  for (;;) {
+    const int c = a % b;
+    if (c == 0) {
+      break;
+    }
+    // Shifts the coefficient history and applies the quotient of this division step
+    const int p_2 = p_1;
+    const int q_2 = q_1;
+    p_1 = p;
+    q_1 = q;
+    p = p_2 - p_1 * (a / b);
+    q = q_2 - q_1 * (a / b);
+    a = b;
+    b = c;
+  }
+  r = b;
+}
+
+// =================================================================================================
} // namespace clblast
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index 16a241af..fcc1c57f 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -372,6 +372,12 @@ std::string GetDeviceArchitecture(const Device& device);
std::string GetDeviceName(const Device& device);
// =================================================================================================
+
+// Solve Bezout's identity
+// a * p + b * q = r = GCD(a, b)
+void EuclidGCD(int a, int b, int &p, int &q, int &r);
+
+// =================================================================================================
} // namespace clblast
// CLBLAST_UTILITIES_H_