summary | refs | log | tree | commit | diff
path: root/src/routines/levelx/xconvgemm.cpp
diff options
context:
space:
mode:
author Koichi Akabe <vbkaisetsu@gmail.com> 2018-12-18 13:56:00 +0900
committer Koichi Akabe <vbkaisetsu@gmail.com> 2018-12-18 13:56:00 +0900
commit 301dc280dfe75ff3c8b219f64aea83a6bf2f0c8d (patch)
tree e47fd45f74bc0dd326e6b120af341861b47f90aa /src/routines/levelx/xconvgemm.cpp
parent 9819957768174dbb4929b970718a0d6018520979 (diff)
Fix xconvgemm kernel and enable ConvGemmMethod::kSingleKernel
Diffstat (limited to 'src/routines/levelx/xconvgemm.cpp')
-rw-r--r-- src/routines/levelx/xconvgemm.cpp 10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index 88127b0f..8bd24f15 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -53,9 +53,6 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode,
const Buffer<T> &kernel_buffer, const size_t kernel_offset,
const Buffer<T> &result_buffer, const size_t result_offset) {
- // TODO: Implement single-kernel approach
- assert(method_ == ConvGemmMethod::kWithIm2Col);
-
// Tests for a valid batch count
if (batch_count == 0) {
throw BLASError(StatusCode::kInvalidBatchCount);
@@ -121,7 +118,12 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode,
}
// Retrieves the proper XgemmDirect kernel from the compiled binary
- auto kernel = Kernel(program_, "Xconvgemm");
+ const std::string kernel_name = (method_ == ConvGemmMethod::kWithIm2Col)
+ ? "Xconvgemm"
+ : (kernel_mode == KernelMode::kConvolution)
+ ? "XconvgemmFlip"
+ : "XconvgemmNormal";
+ auto kernel = Kernel(program_, kernel_name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(num_patches));