diff options
author | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-12-18 13:56:00 +0900 |
---|---|---|
committer | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-12-18 13:56:00 +0900 |
commit | 301dc280dfe75ff3c8b219f64aea83a6bf2f0c8d (patch) | |
tree | e47fd45f74bc0dd326e6b120af341861b47f90aa /src/routines/levelx/xconvgemm.cpp | |
parent | 9819957768174dbb4929b970718a0d6018520979 (diff) |
Fix xconvgemm kernel and enable ConvGemmMethod::kSingleKernel
Diffstat (limited to 'src/routines/levelx/xconvgemm.cpp')
-rw-r--r-- | src/routines/levelx/xconvgemm.cpp | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp index 88127b0f..8bd24f15 100644 --- a/src/routines/levelx/xconvgemm.cpp +++ b/src/routines/levelx/xconvgemm.cpp @@ -53,9 +53,6 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode, const Buffer<T> &kernel_buffer, const size_t kernel_offset, const Buffer<T> &result_buffer, const size_t result_offset) { - // TODO: Implement single-kernel approach - assert(method_ == ConvGemmMethod::kWithIm2Col); - // Tests for a valid batch count if (batch_count == 0) { throw BLASError(StatusCode::kInvalidBatchCount); @@ -121,7 +118,12 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode, } // Retrieves the proper XgemmDirect kernel from the compiled binary - auto kernel = Kernel(program_, "Xconvgemm"); + const std::string kernel_name = (method_ == ConvGemmMethod::kWithIm2Col) + ? "Xconvgemm" + : (kernel_mode == KernelMode::kConvolution) + ? "XconvgemmFlip" + : "XconvgemmNormal"; + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(num_patches)); |