summary | refs | log | tree | commit | diff
path: root/src/routines/levelx/xcol2im.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines/levelx/xcol2im.cpp')
-rw-r--r-- src/routines/levelx/xcol2im.cpp 8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp
index 7a0c36b7..d285e5c0 100644
--- a/src/routines/levelx/xcol2im.cpp
+++ b/src/routines/levelx/xcol2im.cpp
@@ -31,13 +31,17 @@ Xcol2im<T>::Xcol2im(Queue &queue, EventPointer event, const std::string &name):
// The main routine
template <typename T>
-void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size_t width,
+void Xcol2im<T>::DoCol2im(const KernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width,
const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
const size_t pad_w, const size_t stride_h, const size_t stride_w,
const size_t dilation_h, const size_t dilation_w,
const Buffer<T> &col_buffer, const size_t col_offset,
const Buffer<T> &im_buffer, const size_t im_offset) {
+ // Flip the output along kernel_h and kernel_w, or not.
+ const auto kernel_name = (kernel_mode == KernelMode::kConvolution) ? "Xcol2imKernelFlip" : "Xcol2imKernelNormal";
+
// Makes sure all dimensions are larger than zero
if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
@@ -59,7 +63,7 @@ void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size
EuclidGCD(static_cast<int>(stride_w), static_cast<int>(dilation_w), stride_bez_w, dilation_bez_w, gcd_w);
// Retrieves the kernel from the compiled binary
- auto kernel = Kernel(program_, "col2im");
+ auto kernel = Kernel(program_, kernel_name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(height));