Diffstat (limited to 'src/routines/levelx/xconvgemm.cpp')
-rw-r--r--  src/routines/levelx/xconvgemm.cpp | 69
1 file changed, 21 insertions(+), 48 deletions(-)
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index f36f72e5..8cb8093c 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -11,13 +11,12 @@
//
// =================================================================================================
-#include "routines/levelx/xconvgemm.hpp"
-#include "routines/levelx/xim2col.hpp"
-#include "routines/level3/xgemm.hpp"
-
#include <string>
#include <vector>
+#include "routines/levelx/xconvgemm.hpp"
+#include "routines/levelx/xim2col.hpp"
+
namespace clblast {
// =================================================================================================
@@ -32,7 +31,7 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam
#include "../../kernels/level3/xgemm_direct_part2.opencl"
#include "../../kernels/level3/xgemm_direct_part3.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
- #include "../../kernels/level3/xgemm_direct_batched.opencl"
+ #include "../../kernels/level3/xconvgemm.opencl"
}) {
}
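Note on the constructor hunk above: the `.opencl` files are plain string fragments that are concatenated into one program source at routine-construction time, split over several `#include`s only to stay under MSVC 2013's string-literal length limit (the C1091 mentioned in the comment); this patch swaps the batched-GEMM fragment for the dedicated `xconvgemm.opencl` fragment. A minimal sketch of that concatenation pattern, with purely illustrative fragment names and contents (not the actual CLBlast sources), is:

```cpp
#include <string>

// Illustrative fragments: in the real code each #include'd .opencl file
// expands to a string literal containing OpenCL C source.
const char* kPart1Source = R"( /* common defines, e.g. WGD */ )";
const char* kConvgemmSource = R"( /* __kernel void Xconvgemm(...) { ... } */ )";

// Concatenates the fragments into a single source string, roughly what the
// routine constructor's brace-enclosed list of includes amounts to before
// the program is compiled for the device.
std::string BuildProgramSource() {
  std::string source;
  for (const char* fragment : {kPart1Source, kConvgemmSource}) {
    source += fragment;
  }
  return source;
}
```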
@@ -93,61 +92,35 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
}
// Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
- const auto m = num_patches;
- const auto n = num_kernels;
- const auto k = patch_size;
- const auto col_ld = m;
- const auto kernel_ld = k;
- const auto result_ld = m;
const auto col_stride = patch_size * num_patches;
- const auto kernel_stride = size_t{0}; // applies the same kernel to all batches
const auto result_stride = num_kernels * output_h * output_w;
- // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
- bool col_do_transpose, kernel_do_transpose, result_do_transpose, col_conjugate, kernel_conjugate;
- size_t col_one, col_two, kernel_one, kernel_two, result_one, result_two;
- Xgemm<T>::ProcessArguments(Layout::kColMajor, Transpose::kNo, Transpose::kNo, m, n, k,
- col_one, col_two, kernel_one, kernel_two, result_one, result_two,
- col_do_transpose, kernel_do_transpose,
- result_do_transpose, col_conjugate, kernel_conjugate, 0);
-
// Tests the matrices for validity
+ TestMatrixB(patch_size, num_kernels, kernel_buffer, kernel_offset, patch_size);
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
- TestMatrixA(col_one, col_two, col_buffer, col_stride * batch, col_ld);
- TestMatrixB(kernel_one, kernel_two, kernel_buffer, kernel_offset + kernel_stride * batch, kernel_ld);
- TestMatrixC(result_one, result_two, result_buffer, result_offset + result_stride * batch, result_ld);
+ TestMatrixA(num_patches, patch_size, col_buffer, col_stride * batch, num_patches);
+ TestMatrixC(num_patches, num_kernels, result_buffer, result_offset + result_stride * batch, num_patches);
}
// Retrieves the proper XgemmDirect kernel from the compiled binary
- const auto name = (col_do_transpose) ? (kernel_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") :
- (kernel_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN");
- auto kernel = Kernel(program_, name);
+ auto kernel = Kernel(program_, "Xconvgemm");
// Sets the kernel arguments
- kernel.SetArgument(0, static_cast<int>(m));
- kernel.SetArgument(1, static_cast<int>(n));
- kernel.SetArgument(2, static_cast<int>(k));
- kernel.SetArgument(3, GetRealArg(ConstantOne<T>()));
- kernel.SetArgument(4, GetRealArg(ConstantZero<T>()));
- kernel.SetArgument(5, col_buffer());
- kernel.SetArgument(6, static_cast<int>(0));
- kernel.SetArgument(7, static_cast<int>(col_ld));
- kernel.SetArgument(8, static_cast<int>(col_stride));
- kernel.SetArgument(9, kernel_buffer());
- kernel.SetArgument(10, static_cast<int>(kernel_offset));
- kernel.SetArgument(11, static_cast<int>(kernel_ld));
- kernel.SetArgument(12, static_cast<int>(kernel_stride));
- kernel.SetArgument(13, result_buffer());
- kernel.SetArgument(14, static_cast<int>(result_offset));
- kernel.SetArgument(15, static_cast<int>(result_ld));
- kernel.SetArgument(16, static_cast<int>(result_stride));
- kernel.SetArgument(17, static_cast<int>(result_do_transpose));
- kernel.SetArgument(18, static_cast<int>(false));
- kernel.SetArgument(19, static_cast<int>(false));
+ kernel.SetArgument(0, static_cast<int>(num_patches));
+ kernel.SetArgument(1, static_cast<int>(num_kernels));
+ kernel.SetArgument(2, static_cast<int>(patch_size));
+ kernel.SetArgument(3, col_buffer());
+ kernel.SetArgument(4, static_cast<int>(0));
+ kernel.SetArgument(5, static_cast<int>(col_stride));
+ kernel.SetArgument(6, kernel_buffer());
+ kernel.SetArgument(7, static_cast<int>(kernel_offset));
+ kernel.SetArgument(8, result_buffer());
+ kernel.SetArgument(9, static_cast<int>(result_offset));
+ kernel.SetArgument(10, static_cast<int>(result_stride));
// Computes the global and local thread sizes
- const auto m_ceiled = Ceil(m, db_["WGD"]);
- const auto n_ceiled = Ceil(n, db_["WGD"]);
+ const auto m_ceiled = Ceil(num_patches, db_["WGD"]);
+ const auto n_ceiled = Ceil(num_kernels, db_["WGD"]);
const auto global = std::vector<size_t>{
(m_ceiled * db_["MDIMCD"]) / db_["WGD"],
(n_ceiled * db_["NDIMCD"]) / db_["WGD"],
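Taken together, the final hunk replaces the `XgemmDirectStridedBatched*` dispatch (and the `Xgemm<T>::ProcessArguments` transpose handling it required) with a single fused `Xconvgemm` kernel whose eleven arguments are just the im2col matrix sizes, the three buffers with their offsets, and the batch strides. A host-side sketch of how those sizes relate to the convolution parameters, assuming the usual im2col convention (the derivation of `num_patches` and `patch_size` is not shown in this hunk, so it is an assumption here), is:

```cpp
#include <cstddef>

// Illustrative only: how the GEMM-shaped sizes passed to the Xconvgemm kernel
// would typically be derived from the convolution parameters.
struct ConvgemmSizes {
  size_t num_patches;    // GEMM m: one im2col column per output pixel (assumed)
  size_t num_kernels;    // GEMM n: one output channel per filter
  size_t patch_size;     // GEMM k: elements per im2col patch (assumed)
  size_t col_stride;     // per-batch stride of the im2col buffer (from the diff)
  size_t result_stride;  // per-batch stride of the result buffer (from the diff)
};

ConvgemmSizes ComputeSizes(size_t channels, size_t kernel_h, size_t kernel_w,
                           size_t output_h, size_t output_w,
                           size_t num_kernels) {
  ConvgemmSizes s;
  s.num_patches = output_h * output_w;             // assumption (im2col convention)
  s.patch_size = channels * kernel_h * kernel_w;   // assumption (im2col convention)
  s.num_kernels = num_kernels;
  s.col_stride = s.patch_size * s.num_patches;             // matches the diff
  s.result_stride = num_kernels * output_h * output_w;     // matches the diff
  return s;
}
```

Arguments 0-2 are these three sizes, 3-5 the col buffer, offset and stride, 6-7 the kernel-weight buffer and offset, and 8-10 the result buffer, offset and stride. The per-batch kernel stride disappears because, as the removed comment noted, the same filter weights apply to every batch item, which is also why `TestMatrixB` is hoisted out of the per-batch validity loop while `TestMatrixA`/`TestMatrixC` remain inside it.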