From 7b8f8fce6808f2095a68afe97256db7a78f819fa Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sat, 11 Mar 2017 16:02:45 +0100 Subject: Added initial naive version of the batched GEMM routine based on the direct GEMM kernel --- src/routines/level3/xgemm.cpp | 22 ++++++------ src/routines/levelx/xgemmbatched.cpp | 69 +++++++++++++++++++++++++++--------- 2 files changed, 64 insertions(+), 27 deletions(-) (limited to 'src/routines') diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index dc8c64bc..658b22d0 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -104,19 +104,19 @@ void Xgemm::DoGemm(const Layout layout, // Selects which version of GEMM to run const auto do_gemm_direct = (m * n * k < db_["XGEMM_MIN_INDIRECT_SIZE"]); if (do_gemm_direct) { // for small sizes (single kernel) - return GemmDirect(m, n, k, alpha, - a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, - c_buffer, c_offset, c_ld, - a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + GemmDirect(m, n, k, alpha, + a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, + c_buffer, c_offset, c_ld, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); } else { // for larger sizes (pre/post-processing plus a very fast kernel) - return GemmIndirect(m, n, k, alpha, - a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, - c_buffer, c_offset, c_ld, - a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, - a_one, a_two, a_want_rotated, - b_one, b_two, b_want_rotated, - c_one, c_two, c_want_rotated); + GemmIndirect(m, n, k, alpha, + a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, + c_buffer, c_offset, c_ld, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + a_one, a_two, a_want_rotated, + b_one, b_two, b_want_rotated, + c_one, c_two, c_want_rotated); } } diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp index b07425d5..a11ebfd0 100644 --- a/src/routines/levelx/xgemmbatched.cpp +++ b/src/routines/levelx/xgemmbatched.cpp @@ -22,25 +22,12 @@ namespace clblast { // Constructor: forwards to base class constructor template XgemmBatched::XgemmBatched(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, - {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"}, - PrecisionValue(), {}, { - #include "../../kernels/level3/level3.opencl" - #include "../../kernels/level3/copy_fast.opencl" - #include "../../kernels/level3/copy_pad.opencl" - #include "../../kernels/level3/transpose_fast.opencl" - #include "../../kernels/level3/transpose_pad.opencl" - #include "../../kernels/level3/convert_symmetric.opencl" - #include "../../kernels/level3/convert_triangular.opencl" - #include "../../kernels/level3/convert_hermitian.opencl" - , // separated in multiple parts to prevent C1091 in MSVC 2013 + Routine(queue, event, name, {"XgemmDirect"}, PrecisionValue(), {}, { #include "../../kernels/level3/xgemm_direct_part1.opencl" #include "../../kernels/level3/xgemm_direct_part2.opencl" #include "../../kernels/level3/xgemm_direct_part3.opencl" , // separated in multiple parts to prevent C1091 in MSVC 2013 - #include "../../kernels/level3/xgemm_part1.opencl" - #include "../../kernels/level3/xgemm_part2.opencl" - #include "../../kernels/level3/xgemm_part3.opencl" + #include "../../kernels/level3/xgemm_direct_batched.opencl" }) { } @@ -99,7 +86,57 @@ void XgemmBatched::DoGemmBatched(const Layout layout, const Transpose a_trans TestMatrixC(c_one, c_two, c_buffer, c_offsets[batch], c_ld); } - // StatusCode::kNotImplemented; + // Upload the arguments to the device + std::vector a_offsets_int(a_offsets.begin(), a_offsets.end()); + std::vector b_offsets_int(b_offsets.begin(), b_offsets.end()); + std::vector c_offsets_int(c_offsets.begin(), c_offsets.end()); + auto a_offsets_device = Buffer(context_, BufferAccess::kReadOnly, batch_count); + auto b_offsets_device = Buffer(context_, BufferAccess::kReadOnly, batch_count); + auto c_offsets_device = Buffer(context_, BufferAccess::kReadOnly, batch_count); + auto alphas_device = Buffer(context_, BufferAccess::kReadOnly, batch_count); + auto betas_device = Buffer(context_, BufferAccess::kReadOnly, batch_count); + a_offsets_device.Write(queue_, batch_count, a_offsets_int); + b_offsets_device.Write(queue_, batch_count, b_offsets_int); + c_offsets_device.Write(queue_, batch_count, c_offsets_int); + alphas_device.Write(queue_, batch_count, alphas); + betas_device.Write(queue_, batch_count, betas); + + // Retrieves the proper XgemmDirect kernel from the compiled binary + const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectBatchedTT" : "XgemmDirectBatchedTN") : + (b_do_transpose ? "XgemmDirectBatchedNT" : "XgemmDirectBatchedNN"); + auto kernel = Kernel(program_, name); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast(m)); + kernel.SetArgument(1, static_cast(n)); + kernel.SetArgument(2, static_cast(k)); + kernel.SetArgument(3, alphas_device()); + kernel.SetArgument(4, betas_device()); + kernel.SetArgument(5, a_buffer()); + kernel.SetArgument(6, a_offsets_device()); + kernel.SetArgument(7, static_cast(a_ld)); + kernel.SetArgument(8, b_buffer()); + kernel.SetArgument(9, b_offsets_device()); + kernel.SetArgument(10, static_cast(b_ld)); + kernel.SetArgument(11, c_buffer()); + kernel.SetArgument(12, c_offsets_device()); + kernel.SetArgument(13, static_cast(c_ld)); + kernel.SetArgument(14, static_cast(c_do_transpose)); + kernel.SetArgument(15, static_cast(a_conjugate)); + kernel.SetArgument(16, static_cast(b_conjugate)); + + // Computes the global and local thread sizes + const auto m_ceiled = Ceil(m, db_["WGD"]); + const auto n_ceiled = Ceil(n, db_["WGD"]); + const auto global = std::vector{ + (m_ceiled * db_["MDIMCD"]) / db_["WGD"], + (n_ceiled * db_["NDIMCD"]) / db_["WGD"], + batch_count + }; + const auto local = std::vector{db_["MDIMCD"], db_["NDIMCD"], 1}; + + // Launches the kernel + RunKernel(kernel, queue_, device_, global, local, event_); } // ================================================================================================= -- cgit v1.2.3