summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-08-31 21:58:12 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-08-31 21:58:12 +0200
commit297159d5b99f33f4e49cea238e66f1a1f05389a3 (patch)
tree9f5697b515a1d8643bbc1951e8f27b6a88c02211
parent6194d43efba30aac90a64676e7770f020e4a5588 (diff)
Fixed a bug in im2col: process only valid channel IDs
-rw-r--r--src/kernels/levelx/im2col.opencl4
-rw-r--r--src/routines/levelx/xim2col.cpp29
2 files changed, 17 insertions, 16 deletions
diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl
index c3a5e419..a64d6538 100644
--- a/src/kernels/levelx/im2col.opencl
+++ b/src/kernels/levelx/im2col.opencl
@@ -26,7 +26,7 @@ R"(
// =================================================================================================
__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
-void im2col(const int input_h, const int input_w,
+void im2col(const int input_h, const int input_w, const int channels,
const int output_h, const int output_w,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
@@ -39,7 +39,7 @@ void im2col(const int input_h, const int input_w,
const int w_id = get_global_id(0); // image width, max 'output_w'
const int h_id = get_global_id(1) % output_h; // image height, max 'output_h'
const int c_id = get_global_id(1) / output_h; // input channels
- if (h_id < output_h && w_id < output_w) {
+ if (h_id < output_h && w_id < output_w && c_id < channels) {
#pragma unroll
for (int kh_id = 0; kh_id < kernel_h; ++kh_id) { // kernel height
diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp
index 527695c0..dfbb4bb5 100644
--- a/src/routines/levelx/xim2col.cpp
+++ b/src/routines/levelx/xim2col.cpp
@@ -55,20 +55,21 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(height));
kernel.SetArgument(1, static_cast<int>(width));
- kernel.SetArgument(2, static_cast<int>(output_h));
- kernel.SetArgument(3, static_cast<int>(output_w));
- kernel.SetArgument(4, static_cast<int>(kernel_h));
- kernel.SetArgument(5, static_cast<int>(kernel_w));
- kernel.SetArgument(6, static_cast<int>(pad_h));
- kernel.SetArgument(7, static_cast<int>(pad_w));
- kernel.SetArgument(8, static_cast<int>(stride_h));
- kernel.SetArgument(9, static_cast<int>(stride_w));
- kernel.SetArgument(10, static_cast<int>(dilation_h));
- kernel.SetArgument(11, static_cast<int>(dilation_w));
- kernel.SetArgument(12, im_buffer());
- kernel.SetArgument(13, static_cast<int>(im_offset));
- kernel.SetArgument(14, col_buffer());
- kernel.SetArgument(15, static_cast<int>(col_offset));
+ kernel.SetArgument(2, static_cast<int>(channels));
+ kernel.SetArgument(3, static_cast<int>(output_h));
+ kernel.SetArgument(4, static_cast<int>(output_w));
+ kernel.SetArgument(5, static_cast<int>(kernel_h));
+ kernel.SetArgument(6, static_cast<int>(kernel_w));
+ kernel.SetArgument(7, static_cast<int>(pad_h));
+ kernel.SetArgument(8, static_cast<int>(pad_w));
+ kernel.SetArgument(9, static_cast<int>(stride_h));
+ kernel.SetArgument(10, static_cast<int>(stride_w));
+ kernel.SetArgument(11, static_cast<int>(dilation_h));
+ kernel.SetArgument(12, static_cast<int>(dilation_w));
+ kernel.SetArgument(13, im_buffer());
+ kernel.SetArgument(14, static_cast<int>(im_offset));
+ kernel.SetArgument(15, col_buffer());
+ kernel.SetArgument(16, static_cast<int>(col_offset));
// Launches the kernel
const auto w_ceiled = Ceil(output_w, db_["COPY_DIMX"]);