summaryrefslogtreecommitdiff
path: root/src/routines/levelx
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-05-13 21:01:46 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2018-05-13 21:01:46 +0200
commitad8f1027abf6fecbf8f119172644e729a8a94d0c (patch)
treeeb7e6e77321e6b8b05424a5bf1ce6693cde48852 /src/routines/levelx
parent4e6d30088d7c73f13b0ad6db5794b232add2b735 (diff)
Plugged in the code of strided-batched-gemm into convgemm in preparation of a new kernel
Diffstat (limited to 'src/routines/levelx')
-rw-r--r--src/routines/levelx/xconvgemm.cpp90
1 files changed, 74 insertions, 16 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp
index d3b198a2..f36f72e5 100644
--- a/src/routines/levelx/xconvgemm.cpp
+++ b/src/routines/levelx/xconvgemm.cpp
@@ -13,7 +13,7 @@
#include "routines/levelx/xconvgemm.hpp"
#include "routines/levelx/xim2col.hpp"
-#include "routines/levelx/xgemmstridedbatched.hpp"
+#include "routines/level3/xgemm.hpp"
#include <string>
#include <vector>
@@ -24,9 +24,16 @@ namespace clblast {
// Constructor: forwards to base class constructor
template <typename T>
Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name):
- Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, {
-#include "../../kernels/levelx/im2col.opencl"
- }) {
+ Routine(queue, event, name, {"XgemmDirect"},
+ PrecisionValue<T>(), {}, {
+ #include "../../kernels/level3/level3.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
+ #include "../../kernels/level3/xgemm_direct_part1.opencl"
+ #include "../../kernels/level3/xgemm_direct_part2.opencl"
+ #include "../../kernels/level3/xgemm_direct_part3.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
+ #include "../../kernels/level3/xgemm_direct_batched.opencl"
+ }) {
}
// =================================================================================================
@@ -41,8 +48,13 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
const Buffer<T> &kernel_buffer, const size_t kernel_offset,
const Buffer<T> &result_buffer, const size_t result_offset) {
+ // Tests for a valid batch count
+ if (batch_count == 0) {
+ throw BLASError(StatusCode::kInvalidBatchCount);
+ }
+
// Makes sure all dimensions are larger than zero
- if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0) || (batch_count == 0)) {
+ if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0)) {
throw BLASError(StatusCode::kInvalidDimension);
}
@@ -80,7 +92,7 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
im2col_event.WaitForCompletion();
}
- // GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
+ // Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
const auto m = num_patches;
const auto n = num_kernels;
const auto k = patch_size;
@@ -88,17 +100,63 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
const auto kernel_ld = k;
const auto result_ld = m;
const auto col_stride = patch_size * num_patches;
- const auto kernel_stride = size_t{0}; // applies the same kernel to all
+ const auto kernel_stride = size_t{0}; // applies the same kernel to all batches
const auto result_stride = num_kernels * output_h * output_w;
- auto gemm_event = Event();
- auto gemm = XgemmStridedBatched<T>(queue_, gemm_event.pointer());
- gemm.DoGemmStridedBatched(Layout::kColMajor, Transpose::kNo, Transpose::kNo,
- m, n, k, ConstantOne<T>(),
- col_buffer, 0, col_ld, col_stride,
- kernel_buffer, kernel_offset, kernel_ld, kernel_stride, ConstantZero<T>(),
- result_buffer, result_offset, result_ld, result_stride,
- batch_count);
- gemm_event.WaitForCompletion();
+
+ // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+ bool col_do_transpose, kernel_do_transpose, result_do_transpose, col_conjugate, kernel_conjugate;
+ size_t col_one, col_two, kernel_one, kernel_two, result_one, result_two;
+ Xgemm<T>::ProcessArguments(Layout::kColMajor, Transpose::kNo, Transpose::kNo, m, n, k,
+ col_one, col_two, kernel_one, kernel_two, result_one, result_two,
+ col_do_transpose, kernel_do_transpose,
+ result_do_transpose, col_conjugate, kernel_conjugate, 0);
+
+ // Tests the matrices for validity
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ TestMatrixA(col_one, col_two, col_buffer, col_stride * batch, col_ld);
+ TestMatrixB(kernel_one, kernel_two, kernel_buffer, kernel_offset + kernel_stride * batch, kernel_ld);
+ TestMatrixC(result_one, result_two, result_buffer, result_offset + result_stride * batch, result_ld);
+ }
+
+ // Retrieves the proper XgemmDirect kernel from the compiled binary
+ const auto name = (col_do_transpose) ? (kernel_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") :
+ (kernel_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN");
+ auto kernel = Kernel(program_, name);
+
+ // Sets the kernel arguments
+ kernel.SetArgument(0, static_cast<int>(m));
+ kernel.SetArgument(1, static_cast<int>(n));
+ kernel.SetArgument(2, static_cast<int>(k));
+ kernel.SetArgument(3, GetRealArg(ConstantOne<T>()));
+ kernel.SetArgument(4, GetRealArg(ConstantZero<T>()));
+ kernel.SetArgument(5, col_buffer());
+ kernel.SetArgument(6, static_cast<int>(0));
+ kernel.SetArgument(7, static_cast<int>(col_ld));
+ kernel.SetArgument(8, static_cast<int>(col_stride));
+ kernel.SetArgument(9, kernel_buffer());
+ kernel.SetArgument(10, static_cast<int>(kernel_offset));
+ kernel.SetArgument(11, static_cast<int>(kernel_ld));
+ kernel.SetArgument(12, static_cast<int>(kernel_stride));
+ kernel.SetArgument(13, result_buffer());
+ kernel.SetArgument(14, static_cast<int>(result_offset));
+ kernel.SetArgument(15, static_cast<int>(result_ld));
+ kernel.SetArgument(16, static_cast<int>(result_stride));
+ kernel.SetArgument(17, static_cast<int>(result_do_transpose));
+ kernel.SetArgument(18, static_cast<int>(false));
+ kernel.SetArgument(19, static_cast<int>(false));
+
+ // Computes the global and local thread sizes
+ const auto m_ceiled = Ceil(m, db_["WGD"]);
+ const auto n_ceiled = Ceil(n, db_["WGD"]);
+ const auto global = std::vector<size_t>{
+ (m_ceiled * db_["MDIMCD"]) / db_["WGD"],
+ (n_ceiled * db_["NDIMCD"]) / db_["WGD"],
+ batch_count
+ };
+ const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"], 1};
+
+ // Launches the kernel
+ RunKernel(kernel, queue_, device_, global, local, event_);
}
// =================================================================================================