diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-08-31 21:58:12 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-08-31 21:58:12 +0200 |
commit | 297159d5b99f33f4e49cea238e66f1a1f05389a3 (patch) | |
tree | 9f5697b515a1d8643bbc1951e8f27b6a88c02211 /src | |
parent | 6194d43efba30aac90a64676e7770f020e4a5588 (diff) |
Fixed a bug in im2col: process only valid channel IDs
Diffstat (limited to 'src')
-rw-r--r-- | src/kernels/levelx/im2col.opencl | 4 | ||||
-rw-r--r-- | src/routines/levelx/xim2col.cpp | 29 |
2 files changed, 17 insertions, 16 deletions
```diff
diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl
index c3a5e419..a64d6538 100644
--- a/src/kernels/levelx/im2col.opencl
+++ b/src/kernels/levelx/im2col.opencl
@@ -26,7 +26,7 @@ R"(
 // =================================================================================================
 
 __kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
-void im2col(const int input_h, const int input_w,
+void im2col(const int input_h, const int input_w, const int channels,
             const int output_h, const int output_w,
             const int kernel_h, const int kernel_w,
             const int pad_h, const int pad_w,
@@ -39,7 +39,7 @@ void im2col(const int input_h, const int input_w,
   const int w_id = get_global_id(0); // image width, max 'output_w'
   const int h_id = get_global_id(1) % output_h; // image height, max 'output_h'
   const int c_id = get_global_id(1) / output_h; // input channels
-  if (h_id < output_h && w_id < output_w) {
+  if (h_id < output_h && w_id < output_w && c_id < channels) {
 
     #pragma unroll
     for (int kh_id = 0; kh_id < kernel_h; ++kh_id) { // kernel height
diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp
index 527695c0..dfbb4bb5 100644
--- a/src/routines/levelx/xim2col.cpp
+++ b/src/routines/levelx/xim2col.cpp
@@ -55,20 +55,21 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size
   // Sets the kernel arguments
   kernel.SetArgument(0, static_cast<int>(height));
   kernel.SetArgument(1, static_cast<int>(width));
-  kernel.SetArgument(2, static_cast<int>(output_h));
-  kernel.SetArgument(3, static_cast<int>(output_w));
-  kernel.SetArgument(4, static_cast<int>(kernel_h));
-  kernel.SetArgument(5, static_cast<int>(kernel_w));
-  kernel.SetArgument(6, static_cast<int>(pad_h));
-  kernel.SetArgument(7, static_cast<int>(pad_w));
-  kernel.SetArgument(8, static_cast<int>(stride_h));
-  kernel.SetArgument(9, static_cast<int>(stride_w));
-  kernel.SetArgument(10, static_cast<int>(dilation_h));
-  kernel.SetArgument(11, static_cast<int>(dilation_w));
-  kernel.SetArgument(12, im_buffer());
-  kernel.SetArgument(13, static_cast<int>(im_offset));
-  kernel.SetArgument(14, col_buffer());
-  kernel.SetArgument(15, static_cast<int>(col_offset));
+  kernel.SetArgument(2, static_cast<int>(channels));
+  kernel.SetArgument(3, static_cast<int>(output_h));
+  kernel.SetArgument(4, static_cast<int>(output_w));
+  kernel.SetArgument(5, static_cast<int>(kernel_h));
+  kernel.SetArgument(6, static_cast<int>(kernel_w));
+  kernel.SetArgument(7, static_cast<int>(pad_h));
+  kernel.SetArgument(8, static_cast<int>(pad_w));
+  kernel.SetArgument(9, static_cast<int>(stride_h));
+  kernel.SetArgument(10, static_cast<int>(stride_w));
+  kernel.SetArgument(11, static_cast<int>(dilation_h));
+  kernel.SetArgument(12, static_cast<int>(dilation_w));
+  kernel.SetArgument(13, im_buffer());
+  kernel.SetArgument(14, static_cast<int>(im_offset));
+  kernel.SetArgument(15, col_buffer());
+  kernel.SetArgument(16, static_cast<int>(col_offset));
 
   // Launches the kernel
   const auto w_ceiled = Ceil(output_w, db_["COPY_DIMX"]);
```