summaryrefslogtreecommitdiff
path: root/src/routines
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-08-31 21:58:12 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-08-31 21:58:12 +0200
commit297159d5b99f33f4e49cea238e66f1a1f05389a3 (patch)
tree9f5697b515a1d8643bbc1951e8f27b6a88c02211 /src/routines
parent6194d43efba30aac90a64676e7770f020e4a5588 (diff)
Fixed a bug in im2col: process only valid channel IDs
Diffstat (limited to 'src/routines')
-rw-r--r--src/routines/levelx/xim2col.cpp29
1 files changed, 15 insertions, 14 deletions
diff --git a/src/routines/levelx/xim2col.cpp b/src/routines/levelx/xim2col.cpp
index 527695c0..dfbb4bb5 100644
--- a/src/routines/levelx/xim2col.cpp
+++ b/src/routines/levelx/xim2col.cpp
@@ -55,20 +55,21 @@ void Xim2col<T>::DoIm2col(const size_t channels, const size_t height, const size
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(height));
kernel.SetArgument(1, static_cast<int>(width));
- kernel.SetArgument(2, static_cast<int>(output_h));
- kernel.SetArgument(3, static_cast<int>(output_w));
- kernel.SetArgument(4, static_cast<int>(kernel_h));
- kernel.SetArgument(5, static_cast<int>(kernel_w));
- kernel.SetArgument(6, static_cast<int>(pad_h));
- kernel.SetArgument(7, static_cast<int>(pad_w));
- kernel.SetArgument(8, static_cast<int>(stride_h));
- kernel.SetArgument(9, static_cast<int>(stride_w));
- kernel.SetArgument(10, static_cast<int>(dilation_h));
- kernel.SetArgument(11, static_cast<int>(dilation_w));
- kernel.SetArgument(12, im_buffer());
- kernel.SetArgument(13, static_cast<int>(im_offset));
- kernel.SetArgument(14, col_buffer());
- kernel.SetArgument(15, static_cast<int>(col_offset));
+ kernel.SetArgument(2, static_cast<int>(channels));
+ kernel.SetArgument(3, static_cast<int>(output_h));
+ kernel.SetArgument(4, static_cast<int>(output_w));
+ kernel.SetArgument(5, static_cast<int>(kernel_h));
+ kernel.SetArgument(6, static_cast<int>(kernel_w));
+ kernel.SetArgument(7, static_cast<int>(pad_h));
+ kernel.SetArgument(8, static_cast<int>(pad_w));
+ kernel.SetArgument(9, static_cast<int>(stride_h));
+ kernel.SetArgument(10, static_cast<int>(stride_w));
+ kernel.SetArgument(11, static_cast<int>(dilation_h));
+ kernel.SetArgument(12, static_cast<int>(dilation_w));
+ kernel.SetArgument(13, im_buffer());
+ kernel.SetArgument(14, static_cast<int>(im_offset));
+ kernel.SetArgument(15, col_buffer());
+ kernel.SetArgument(16, static_cast<int>(col_offset));
// Launches the kernel
const auto w_ceiled = Ceil(output_w, db_["COPY_DIMX"]);