summaryrefslogtreecommitdiff
path: root/src/kernels
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-08-31 21:58:12 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-08-31 21:58:12 +0200
commit297159d5b99f33f4e49cea238e66f1a1f05389a3 (patch)
tree9f5697b515a1d8643bbc1951e8f27b6a88c02211 /src/kernels
parent6194d43efba30aac90a64676e7770f020e4a5588 (diff)
Fixed a bug in im2col: process only valid channel IDs
Diffstat (limited to 'src/kernels')
-rw-r--r--src/kernels/levelx/im2col.opencl4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/kernels/levelx/im2col.opencl b/src/kernels/levelx/im2col.opencl
index c3a5e419..a64d6538 100644
--- a/src/kernels/levelx/im2col.opencl
+++ b/src/kernels/levelx/im2col.opencl
@@ -26,7 +26,7 @@ R"(
// =================================================================================================
__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
-void im2col(const int input_h, const int input_w,
+void im2col(const int input_h, const int input_w, const int channels,
const int output_h, const int output_w,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
@@ -39,7 +39,7 @@ void im2col(const int input_h, const int input_w,
const int w_id = get_global_id(0); // image width, max 'output_w'
const int h_id = get_global_id(1) % output_h; // image height, max 'output_h'
const int c_id = get_global_id(1) / output_h; // input channels
- if (h_id < output_h && w_id < output_w) {
+ if (h_id < output_h && w_id < output_w && c_id < channels) {
#pragma unroll
for (int kh_id = 0; kh_id < kernel_h; ++kh_id) { // kernel height