summary | refs | log | tree | commit | diff
path: root/src/routines/levelx/xconvgemm.cpp
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2018-05-21 11:28:11 +0200
committer: Cedric Nugteren <web@cedricnugteren.nl> 2018-05-21 11:28:11 +0200
commit: 5d87abf78080de8e844ff93822da49d2c8a7deb3 (patch)
tree: c07d850368ae0ecfce2268551414355c98a679bb /src/routines/levelx/xconvgemm.cpp
parent: 37cabd4f1f144557aa378d944af53a94fc1ff6d1 (diff)
Added method selection option to switch between im2col and single-kernel approach for convgemm
Diffstat (limited to 'src/routines/levelx/xconvgemm.cpp')
-rw-r--r-- src/routines/levelx/xconvgemm.cpp 105
1 files changed, 60 insertions, 45 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index 055a3dda..5ad39751 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -22,9 +22,11 @@ namespace clblast {
// Constructor: forwards to base class constructor
template <typename T>
-Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name):
+Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name,
+ const ConvGemmMethod method):
Routine(queue, event, name, {"XgemmDirect"},
PrecisionValue<T>(), {}, {
+ (method == ConvGemmMethod::kWithIm2Col) ? "#define CONVGEMM_WITH_IM2COL\n" : "",
#include "../../kernels/level3/level3.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_direct_part1.opencl"
@@ -33,7 +35,8 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/levelx/xconvgemm_part1.opencl"
#include "../../kernels/levelx/xconvgemm_part2.opencl"
- }) {
+ }),
+ method_(method) {
}
// =================================================================================================
@@ -70,26 +73,29 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
const auto patch_size = kernel_h * kernel_w * channels;
const auto num_patches = output_h * output_w;
- // Approach: im2col + GEMM
+ // Possible approach: im2col + GEMM
// result = GEMM(im2col(image), kernel)
-
- // Temporary col matrix
- const auto col_size = patch_size * num_patches * batch_count;
- auto col_buffer = Buffer<T>(context_, col_size);
-
- // Loops over each batch
- for (auto batch_id = size_t{0}; batch_id < batch_count; ++batch_id) {
-
- // im2col
- const auto im_batch_offset = batch_id * channels * height * width + im_offset;
- const auto col_batch_offset = batch_id * patch_size * num_patches;
- auto im2col_event = Event();
- auto im2col = Xim2col<T>(queue_, im2col_event.pointer());
- im2col.DoIm2col(channels, height, width, kernel_h, kernel_w,
- pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
- im_buffer, im_batch_offset,
- col_buffer, col_batch_offset);
- im2col_event.WaitForCompletion();
+ auto col_buffer = Buffer<T>(context_, 0); // nullptr, will be optionally created later
+ if (method_ == ConvGemmMethod::kWithIm2Col) {
+
+ // Temporary col matrix
+ const auto col_size = (method_ == ConvGemmMethod::kWithIm2Col) ? patch_size * num_patches * batch_count : 1;
+ col_buffer = Buffer<T>(context_, col_size);
+
+ // Loops over each batch
+ for (auto batch_id = size_t{0}; batch_id < batch_count; ++batch_id) {
+
+ // im2col
+ const auto im_batch_offset = batch_id * channels * height * width + im_offset;
+ const auto col_batch_offset = batch_id * patch_size * num_patches;
+ auto im2col_event = Event();
+ auto im2col = Xim2col<T>(queue_, im2col_event.pointer());
+ im2col.DoIm2col(channels, height, width, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ im_buffer, im_batch_offset,
+ col_buffer, col_batch_offset);
+ im2col_event.WaitForCompletion();
+ }
}
// Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
@@ -99,7 +105,12 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
// Tests the matrices for validity
TestMatrixB(patch_size, num_kernels, kernel_buffer, kernel_offset, patch_size);
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
- TestMatrixA(num_patches, patch_size, col_buffer, col_stride * batch, num_patches);
+ if (method_ == ConvGemmMethod::kWithIm2Col) {
+ TestMatrixA(num_patches, patch_size, col_buffer, col_stride * batch, num_patches);
+ }
+ else {
+ // TODO: check for valid image tensor
+ }
TestMatrixC(num_patches, num_kernels, result_buffer, result_offset + result_stride * batch, num_patches);
}
@@ -110,29 +121,33 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
kernel.SetArgument(0, static_cast<int>(num_patches));
kernel.SetArgument(1, static_cast<int>(num_kernels));
kernel.SetArgument(2, static_cast<int>(patch_size));
- kernel.SetArgument(3, col_buffer());
- kernel.SetArgument(4, static_cast<int>(0));
- kernel.SetArgument(5, static_cast<int>(col_stride));
- kernel.SetArgument(6, kernel_buffer());
- kernel.SetArgument(7, static_cast<int>(kernel_offset));
- kernel.SetArgument(8, result_buffer());
- kernel.SetArgument(9, static_cast<int>(result_offset));
- kernel.SetArgument(10, static_cast<int>(result_stride));
- kernel.SetArgument(11, static_cast<int>(height));
- kernel.SetArgument(12, static_cast<int>(width));
- kernel.SetArgument(13, static_cast<int>(channels));
- kernel.SetArgument(14, static_cast<int>(kernel_h));
- kernel.SetArgument(15, static_cast<int>(kernel_w));
- kernel.SetArgument(16, static_cast<int>(pad_h));
- kernel.SetArgument(17, static_cast<int>(pad_w));
- kernel.SetArgument(18, static_cast<int>(stride_h));
- kernel.SetArgument(19, static_cast<int>(stride_w));
- kernel.SetArgument(20, static_cast<int>(dilation_h));
- kernel.SetArgument(21, static_cast<int>(dilation_w));
- kernel.SetArgument(22, im_buffer());
- kernel.SetArgument(23, static_cast<int>(im_offset));
- kernel.SetArgument(24, static_cast<int>(output_h));
- kernel.SetArgument(25, static_cast<int>(output_w));
+ kernel.SetArgument(3, kernel_buffer());
+ kernel.SetArgument(4, static_cast<int>(kernel_offset));
+ kernel.SetArgument(5, result_buffer());
+ kernel.SetArgument(6, static_cast<int>(result_offset));
+ kernel.SetArgument(7, static_cast<int>(result_stride));
+ if (method_ == ConvGemmMethod::kWithIm2Col) {
+ kernel.SetArgument(8, col_buffer());
+ kernel.SetArgument(9, static_cast<int>(0));
+ kernel.SetArgument(10, static_cast<int>(col_stride));
+ }
+ if (method_ == ConvGemmMethod::kSingleKernel) {
+ kernel.SetArgument(8, im_buffer());
+ kernel.SetArgument(9, static_cast<int>(im_offset));
+ kernel.SetArgument(10, static_cast<int>(height));
+ kernel.SetArgument(11, static_cast<int>(width));
+ kernel.SetArgument(12, static_cast<int>(channels));
+ kernel.SetArgument(13, static_cast<int>(kernel_h));
+ kernel.SetArgument(14, static_cast<int>(kernel_w));
+ kernel.SetArgument(15, static_cast<int>(pad_h));
+ kernel.SetArgument(16, static_cast<int>(pad_w));
+ kernel.SetArgument(17, static_cast<int>(stride_h));
+ kernel.SetArgument(18, static_cast<int>(stride_w));
+ kernel.SetArgument(19, static_cast<int>(dilation_h));
+ kernel.SetArgument(20, static_cast<int>(dilation_w));
+ kernel.SetArgument(21, static_cast<int>(output_h));
+ kernel.SetArgument(22, static_cast<int>(output_w));
+ }
// Computes the global and local thread sizes
const auto m_ceiled = Ceil(num_patches, db_["WGD"]);