diff options
Diffstat (limited to 'src/routines/levelx')
-rw-r--r-- | src/routines/levelx/xconvgemm.cpp | 12 | ||||
-rw-r--r-- | src/routines/levelx/xconvgemm.hpp | 2 |
2 files changed, 8 insertions, 6 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp index 88127b0f..d137e6fe 100644 --- a/src/routines/levelx/xconvgemm.cpp +++ b/src/routines/levelx/xconvgemm.cpp @@ -25,7 +25,7 @@ namespace clblast { template <typename T> Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name, const ConvGemmMethod method): - Routine(queue, event, name, {"XgemmDirect"}, + Routine(queue, event, name, {"Xconvgemm"}, PrecisionValue<T>(), {}, { (method == ConvGemmMethod::kWithIm2Col) ? "#define CONVGEMM_WITH_IM2COL\n" : "", #include "../../kernels/level3/level3.opencl" @@ -53,9 +53,6 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode, const Buffer<T> &kernel_buffer, const size_t kernel_offset, const Buffer<T> &result_buffer, const size_t result_offset) { - // TODO: Implement single-kernel approach - assert(method_ == ConvGemmMethod::kWithIm2Col); - // Tests for a valid batch count if (batch_count == 0) { throw BLASError(StatusCode::kInvalidBatchCount); @@ -121,7 +118,12 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode, } // Retrieves the proper XgemmDirect kernel from the compiled binary - auto kernel = Kernel(program_, "Xconvgemm"); + const std::string kernel_name = (method_ == ConvGemmMethod::kWithIm2Col) + ? "Xconvgemm" + : (kernel_mode == KernelMode::kConvolution) + ? "XconvgemmFlip" + : "XconvgemmNormal"; + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(num_patches)); diff --git a/src/routines/levelx/xconvgemm.hpp b/src/routines/levelx/xconvgemm.hpp index 20cfff60..16082fc6 100644 --- a/src/routines/levelx/xconvgemm.hpp +++ b/src/routines/levelx/xconvgemm.hpp @@ -29,7 +29,7 @@ class Xconvgemm: public Routine { // Constructor enum class ConvGemmMethod {kWithIm2Col, kSingleKernel}; Xconvgemm(Queue &queue, EventPointer event, const std::string &name = "CONVGEMM", - const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col); + const ConvGemmMethod method = ConvGemmMethod::kSingleKernel); // Templated-precision implementation of the routine void DoConvgemm(const KernelMode kernel_mode, |