summaryrefslogtreecommitdiff
path: root/src/routines/levelx
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-09-07 22:02:44 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2018-09-07 22:02:44 +0200
commitc788e040f7f4e46d9f03644cadb65788fe42571e (patch)
tree46cfc164d5e0104174ee1b7ff1489eee2b23688d /src/routines/levelx
parent2dd539f911dc9e53f188ed404ba95a795ee56fb6 (diff)
Added xCONVGEMM as im2col plus a batched GEMM kernel
Diffstat (limited to 'src/routines/levelx')
-rw-r--r--src/routines/levelx/xconvgemm.cpp4
-rw-r--r--src/routines/levelx/xconvgemm.hpp2
2 files changed, 5 insertions, 1 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index 5ad39751..f26f23a7 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -13,6 +13,7 @@
#include <string>
#include <vector>
+#include <assert.h>
#include "routines/levelx/xconvgemm.hpp"
#include "routines/levelx/xim2col.hpp"
@@ -51,6 +52,9 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
const Buffer<T> &kernel_buffer, const size_t kernel_offset,
const Buffer<T> &result_buffer, const size_t result_offset) {
+ // TODO: Implement single-kernel approach
+ assert(method_ == ConvGemmMethod::kWithIm2Col);
+
// Tests for a valid batch count
if (batch_count == 0) {
throw BLASError(StatusCode::kInvalidBatchCount);
diff --git a/src/routines/levelx/xconvgemm.hpp b/src/routines/levelx/xconvgemm.hpp
index ac27657f..9d11ccee 100644
--- a/src/routines/levelx/xconvgemm.hpp
+++ b/src/routines/levelx/xconvgemm.hpp
@@ -29,7 +29,7 @@ class Xconvgemm: public Routine {
// Constructor
enum class ConvGemmMethod {kWithIm2Col, kSingleKernel};
Xconvgemm(Queue &queue, EventPointer event, const std::string &name = "CONVGEMM",
- const ConvGemmMethod method = ConvGemmMethod::kSingleKernel);
+ const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col);
// Templated-precision implementation of the routine
void DoConvgemm(const size_t channels, const size_t height, const size_t width,