diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-05-13 21:01:46 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-05-13 21:01:46 +0200 |
commit | ad8f1027abf6fecbf8f119172644e729a8a94d0c (patch) | |
tree | eb7e6e77321e6b8b05424a5bf1ce6693cde48852 /src/routines/levelx | |
parent | 4e6d30088d7c73f13b0ad6db5794b232add2b735 (diff) |
Plugged in the code of strided-batched-gemm into convgemm in preparation of a new kernel
Diffstat (limited to 'src/routines/levelx')
-rw-r--r-- | src/routines/levelx/xconvgemm.cpp | 90 |
1 file changed, 74 insertions, 16 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp index d3b198a2..f36f72e5 100644 --- a/src/routines/levelx/xconvgemm.cpp +++ b/src/routines/levelx/xconvgemm.cpp @@ -13,7 +13,7 @@ #include "routines/levelx/xconvgemm.hpp" #include "routines/levelx/xim2col.hpp" -#include "routines/levelx/xgemmstridedbatched.hpp" +#include "routines/level3/xgemm.hpp" #include <string> #include <vector> @@ -24,9 +24,16 @@ namespace clblast { // Constructor: forwards to base class constructor template <typename T> Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, { -#include "../../kernels/levelx/im2col.opencl" - }) { + Routine(queue, event, name, {"XgemmDirect"}, + PrecisionValue<T>(), {}, { + #include "../../kernels/level3/level3.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 + #include "../../kernels/level3/xgemm_direct_part1.opencl" + #include "../../kernels/level3/xgemm_direct_part2.opencl" + #include "../../kernels/level3/xgemm_direct_part3.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 + #include "../../kernels/level3/xgemm_direct_batched.opencl" + }) { } // ================================================================================================= @@ -41,8 +48,13 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const const Buffer<T> &kernel_buffer, const size_t kernel_offset, const Buffer<T> &result_buffer, const size_t result_offset) { + // Tests for a valid batch count + if (batch_count == 0) { + throw BLASError(StatusCode::kInvalidBatchCount); + } + // Makes sure all dimensions are larger than zero - if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0) || (batch_count == 0)) { + if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0)) { throw BLASError(StatusCode::kInvalidDimension); } @@ -80,7 +92,7 @@ void 
Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const im2col_event.WaitForCompletion(); } - // GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result) + // Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result) const auto m = num_patches; const auto n = num_kernels; const auto k = patch_size; @@ -88,17 +100,63 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const const auto kernel_ld = k; const auto result_ld = m; const auto col_stride = patch_size * num_patches; - const auto kernel_stride = size_t{0}; // applies the same kernel to all + const auto kernel_stride = size_t{0}; // applies the same kernel to all batches const auto result_stride = num_kernels * output_h * output_w; - auto gemm_event = Event(); - auto gemm = XgemmStridedBatched<T>(queue_, gemm_event.pointer()); - gemm.DoGemmStridedBatched(Layout::kColMajor, Transpose::kNo, Transpose::kNo, - m, n, k, ConstantOne<T>(), - col_buffer, 0, col_ld, col_stride, - kernel_buffer, kernel_offset, kernel_ld, kernel_stride, ConstantZero<T>(), - result_buffer, result_offset, result_ld, result_stride, - batch_count); - gemm_event.WaitForCompletion(); + + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that + bool col_do_transpose, kernel_do_transpose, result_do_transpose, col_conjugate, kernel_conjugate; + size_t col_one, col_two, kernel_one, kernel_two, result_one, result_two; + Xgemm<T>::ProcessArguments(Layout::kColMajor, Transpose::kNo, Transpose::kNo, m, n, k, + col_one, col_two, kernel_one, kernel_two, result_one, result_two, + col_do_transpose, kernel_do_transpose, + result_do_transpose, col_conjugate, kernel_conjugate, 0); + + // Tests the matrices for validity + for (auto batch = size_t{0}; batch < batch_count; ++batch) { + TestMatrixA(col_one, col_two, col_buffer, col_stride * batch, col_ld); + TestMatrixB(kernel_one, kernel_two, kernel_buffer, kernel_offset + kernel_stride * 
batch, kernel_ld); + TestMatrixC(result_one, result_two, result_buffer, result_offset + result_stride * batch, result_ld); + } + + // Retrieves the proper XgemmDirect kernel from the compiled binary + const auto name = (col_do_transpose) ? (kernel_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") : + (kernel_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN"); + auto kernel = Kernel(program_, name); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast<int>(m)); + kernel.SetArgument(1, static_cast<int>(n)); + kernel.SetArgument(2, static_cast<int>(k)); + kernel.SetArgument(3, GetRealArg(ConstantOne<T>())); + kernel.SetArgument(4, GetRealArg(ConstantZero<T>())); + kernel.SetArgument(5, col_buffer()); + kernel.SetArgument(6, static_cast<int>(0)); + kernel.SetArgument(7, static_cast<int>(col_ld)); + kernel.SetArgument(8, static_cast<int>(col_stride)); + kernel.SetArgument(9, kernel_buffer()); + kernel.SetArgument(10, static_cast<int>(kernel_offset)); + kernel.SetArgument(11, static_cast<int>(kernel_ld)); + kernel.SetArgument(12, static_cast<int>(kernel_stride)); + kernel.SetArgument(13, result_buffer()); + kernel.SetArgument(14, static_cast<int>(result_offset)); + kernel.SetArgument(15, static_cast<int>(result_ld)); + kernel.SetArgument(16, static_cast<int>(result_stride)); + kernel.SetArgument(17, static_cast<int>(result_do_transpose)); + kernel.SetArgument(18, static_cast<int>(false)); + kernel.SetArgument(19, static_cast<int>(false)); + + // Computes the global and local thread sizes + const auto m_ceiled = Ceil(m, db_["WGD"]); + const auto n_ceiled = Ceil(n, db_["WGD"]); + const auto global = std::vector<size_t>{ + (m_ceiled * db_["MDIMCD"]) / db_["WGD"], + (n_ceiled * db_["NDIMCD"]) / db_["WGD"], + batch_count + }; + const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"], 1}; + + // Launches the kernel + RunKernel(kernel, queue_, device_, global, local, event_); } // 
================================================================================================= |