diff options
Diffstat (limited to 'src/routines/levelx/xconvgemm.cpp')
-rw-r--r-- | src/routines/levelx/xconvgemm.cpp | 69 |
1 files changed, 21 insertions, 48 deletions
diff --git a/src/routines/levelx/xconvgemm.cpp b/src/routines/levelx/xconvgemm.cpp index f36f72e5..8cb8093c 100644 --- a/src/routines/levelx/xconvgemm.cpp +++ b/src/routines/levelx/xconvgemm.cpp @@ -11,13 +11,12 @@ // // ================================================================================================= -#include "routines/levelx/xconvgemm.hpp" -#include "routines/levelx/xim2col.hpp" -#include "routines/level3/xgemm.hpp" - #include <string> #include <vector> +#include "routines/levelx/xconvgemm.hpp" +#include "routines/levelx/xim2col.hpp" + namespace clblast { // ================================================================================================= @@ -32,7 +31,7 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam #include "../../kernels/level3/xgemm_direct_part2.opencl" #include "../../kernels/level3/xgemm_direct_part3.opencl" , // separated in multiple parts to prevent C1091 in MSVC 2013 - #include "../../kernels/level3/xgemm_direct_batched.opencl" + #include "../../kernels/level3/xconvgemm.opencl" }) { } @@ -93,61 +92,35 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const } // Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result) - const auto m = num_patches; - const auto n = num_kernels; - const auto k = patch_size; - const auto col_ld = m; - const auto kernel_ld = k; - const auto result_ld = m; const auto col_stride = patch_size * num_patches; - const auto kernel_stride = size_t{0}; // applies the same kernel to all batches const auto result_stride = num_kernels * output_h * output_w; - // Computes the transpose/conjugate options and sets the a/b/c sizes based on that - bool col_do_transpose, kernel_do_transpose, result_do_transpose, col_conjugate, kernel_conjugate; - size_t col_one, col_two, kernel_one, kernel_two, result_one, result_two; - Xgemm<T>::ProcessArguments(Layout::kColMajor, Transpose::kNo, Transpose::kNo, m, n, k, - col_one, col_two, kernel_one, kernel_two, result_one, result_two, - col_do_transpose, kernel_do_transpose, - result_do_transpose, col_conjugate, kernel_conjugate, 0); - // Tests the matrices for validity + TestMatrixB(patch_size, num_kernels, kernel_buffer, kernel_offset, patch_size); for (auto batch = size_t{0}; batch < batch_count; ++batch) { - TestMatrixA(col_one, col_two, col_buffer, col_stride * batch, col_ld); - TestMatrixB(kernel_one, kernel_two, kernel_buffer, kernel_offset + kernel_stride * batch, kernel_ld); - TestMatrixC(result_one, result_two, result_buffer, result_offset + result_stride * batch, result_ld); + TestMatrixA(num_patches, patch_size, col_buffer, col_stride * batch, num_patches); + TestMatrixC(num_patches, num_kernels, result_buffer, result_offset + result_stride * batch, num_patches); } // Retrieves the proper XgemmDirect kernel from the compiled binary - const auto name = (col_do_transpose) ? (kernel_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") : - (kernel_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN"); - auto kernel = Kernel(program_, name); + auto kernel = Kernel(program_, "Xconvgemm"); // Sets the kernel arguments - kernel.SetArgument(0, static_cast<int>(m)); - kernel.SetArgument(1, static_cast<int>(n)); - kernel.SetArgument(2, static_cast<int>(k)); - kernel.SetArgument(3, GetRealArg(ConstantOne<T>())); - kernel.SetArgument(4, GetRealArg(ConstantZero<T>())); - kernel.SetArgument(5, col_buffer()); - kernel.SetArgument(6, static_cast<int>(0)); - kernel.SetArgument(7, static_cast<int>(col_ld)); - kernel.SetArgument(8, static_cast<int>(col_stride)); - kernel.SetArgument(9, kernel_buffer()); - kernel.SetArgument(10, static_cast<int>(kernel_offset)); - kernel.SetArgument(11, static_cast<int>(kernel_ld)); - kernel.SetArgument(12, static_cast<int>(kernel_stride)); - kernel.SetArgument(13, result_buffer()); - kernel.SetArgument(14, static_cast<int>(result_offset)); - kernel.SetArgument(15, static_cast<int>(result_ld)); - kernel.SetArgument(16, static_cast<int>(result_stride)); - kernel.SetArgument(17, static_cast<int>(result_do_transpose)); - kernel.SetArgument(18, static_cast<int>(false)); - kernel.SetArgument(19, static_cast<int>(false)); + kernel.SetArgument(0, static_cast<int>(num_patches)); + kernel.SetArgument(1, static_cast<int>(num_kernels)); + kernel.SetArgument(2, static_cast<int>(patch_size)); + kernel.SetArgument(3, col_buffer()); + kernel.SetArgument(4, static_cast<int>(0)); + kernel.SetArgument(5, static_cast<int>(col_stride)); + kernel.SetArgument(6, kernel_buffer()); + kernel.SetArgument(7, static_cast<int>(kernel_offset)); + kernel.SetArgument(8, result_buffer()); + kernel.SetArgument(9, static_cast<int>(result_offset)); + kernel.SetArgument(10, static_cast<int>(result_stride)); // Computes the global and local thread sizes - const auto m_ceiled = Ceil(m, db_["WGD"]); - const auto n_ceiled = Ceil(n, db_["WGD"]); + const auto m_ceiled = Ceil(num_patches, db_["WGD"]); + const auto n_ceiled = Ceil(num_kernels, db_["WGD"]); const auto global = std::vector<size_t>{ (m_ceiled * db_["MDIMCD"]) / db_["WGD"], (n_ceiled * db_["NDIMCD"]) / db_["WGD"], |