diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-09-07 22:02:44 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-09-07 22:02:44 +0200 |
commit | c788e040f7f4e46d9f03644cadb65788fe42571e (patch) | |
tree | 46cfc164d5e0104174ee1b7ff1489eee2b23688d /src/routines/levelx | |
parent | 2dd539f911dc9e53f188ed404ba95a795ee56fb6 (diff) |
Added xCONVGEMM as im2col plus a batched GEMM kernel
Diffstat (limited to 'src/routines/levelx')
-rw-r--r-- | src/routines/levelx/xconvgemm.cpp | 4 | ||||
-rw-r--r-- | src/routines/levelx/xconvgemm.hpp | 2 |
2 files changed, 5 insertions, 1 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp index 5ad39751..f26f23a7 100644 --- a/src/routines/levelx/xconvgemm.cpp +++ b/src/routines/levelx/xconvgemm.cpp @@ -13,6 +13,7 @@ #include <string> #include <vector> +#include <assert.h> #include "routines/levelx/xconvgemm.hpp" #include "routines/levelx/xim2col.hpp" @@ -51,6 +52,9 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const const Buffer<T> &kernel_buffer, const size_t kernel_offset, const Buffer<T> &result_buffer, const size_t result_offset) { + // TODO: Implement single-kernel approach + assert(method_ == ConvGemmMethod::kWithIm2Col); + // Tests for a valid batch count if (batch_count == 0) { throw BLASError(StatusCode::kInvalidBatchCount); diff --git a/src/routines/levelx/xconvgemm.hpp b/src/routines/levelx/xconvgemm.hpp index ac27657f..9d11ccee 100644 --- a/src/routines/levelx/xconvgemm.hpp +++ b/src/routines/levelx/xconvgemm.hpp @@ -29,7 +29,7 @@ class Xconvgemm: public Routine { // Constructor enum class ConvGemmMethod {kWithIm2Col, kSingleKernel}; Xconvgemm(Queue &queue, EventPointer event, const std::string &name = "CONVGEMM", - const ConvGemmMethod method = ConvGemmMethod::kSingleKernel); + const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col); // Templated-precision implementation of the routine void DoConvgemm(const size_t channels, const size_t height, const size_t width, |