summary refs log tree commit diff
path: root/src/routines/levelx
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2019-01-19 17:56:05 +0100
committer GitHub <noreply@github.com> 2019-01-19 17:56:05 +0100
commit 9a9c24e811ddefb6e9d462288916ff64dbf47d63 (patch)
tree 43504e80dc48a3230a497df83b9e15baf50928ea /src/routines/levelx
parent afcf5dc6ebc287b392edcb6bd3ac48966ba98e3c (diff)
parent 11f4c7dd936146f9b4f165d8ef69bafa3a33ad26 (diff)
Merge pull request #345 from CNugteren/convolution-fixes-and-tuner
Convolution with single kernel
Diffstat (limited to 'src/routines/levelx')
-rw-r--r-- src/routines/levelx/xconvgemm.cpp 12
-rw-r--r-- src/routines/levelx/xconvgemm.hpp 2
2 files changed, 8 insertions, 6 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index 88127b0f..d137e6fe 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -25,7 +25,7 @@ namespace clblast {
template <typename T>
Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name,
const ConvGemmMethod method):
- Routine(queue, event, name, {"XgemmDirect"},
+ Routine(queue, event, name, {"Xconvgemm"},
PrecisionValue<T>(), {}, {
(method == ConvGemmMethod::kWithIm2Col) ? "#define CONVGEMM_WITH_IM2COL\n" : "",
#include "../../kernels/level3/level3.opencl"
@@ -53,9 +53,6 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode,
const Buffer<T> &kernel_buffer, const size_t kernel_offset,
const Buffer<T> &result_buffer, const size_t result_offset) {
- // TODO: Implement single-kernel approach
- assert(method_ == ConvGemmMethod::kWithIm2Col);
-
// Tests for a valid batch count
if (batch_count == 0) {
throw BLASError(StatusCode::kInvalidBatchCount);
@@ -121,7 +118,12 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode,
}
// Retrieves the proper XgemmDirect kernel from the compiled binary
- auto kernel = Kernel(program_, "Xconvgemm");
+ const std::string kernel_name = (method_ == ConvGemmMethod::kWithIm2Col)
+ ? "Xconvgemm"
+ : (kernel_mode == KernelMode::kConvolution)
+ ? "XconvgemmFlip"
+ : "XconvgemmNormal";
+ auto kernel = Kernel(program_, kernel_name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(num_patches));
diff --git a/src/routines/levelx/xconvgemm.hpp b/src/routines/levelx/xconvgemm.hpp
index 20cfff60..16082fc6 100644
--- a/src/routines/levelx/xconvgemm.hpp
+++ b/src/routines/levelx/xconvgemm.hpp
@@ -29,7 +29,7 @@ class Xconvgemm: public Routine {
// Constructor
enum class ConvGemmMethod {kWithIm2Col, kSingleKernel};
Xconvgemm(Queue &queue, EventPointer event, const std::string &name = "CONVGEMM",
- const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col);
+ const ConvGemmMethod method = ConvGemmMethod::kSingleKernel);
// Templated-precision implementation of the routine
void DoConvgemm(const KernelMode kernel_mode,