diff options
Diffstat (limited to 'src/routines/levelx/xconvgemm.cpp')
-rw-r--r-- | src/routines/levelx/xconvgemm.cpp | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp index f26f23a7..88127b0f 100644 --- a/src/routines/levelx/xconvgemm.cpp +++ b/src/routines/levelx/xconvgemm.cpp @@ -43,7 +43,8 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam // ================================================================================================= template <typename T> -void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const size_t width, +void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, @@ -94,7 +95,8 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const const auto col_batch_offset = batch_id * patch_size * num_patches; auto im2col_event = Event(); auto im2col = Xim2col<T>(queue_, im2col_event.pointer()); - im2col.DoIm2col(channels, height, width, kernel_h, kernel_w, + im2col.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, im_buffer, im_batch_offset, col_buffer, col_batch_offset); |