From 13f0f6fc6e612a5f77c6fd78b983f1b2bb8e36b6 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sun, 7 Jan 2018 14:58:45 +0100 Subject: Implemented direct version of strided-batched GEMM kernel --- src/routines/levelx/xgemmstridedbatched.cpp | 36 +++++++++++++++-------------- 1 file changed, 19 insertions(+), 17 deletions(-) (limited to 'src/routines/levelx') diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp index 3ea52980..ddf7d878 100644 --- a/src/routines/levelx/xgemmstridedbatched.cpp +++ b/src/routines/levelx/xgemmstridedbatched.cpp @@ -112,6 +112,7 @@ void XgemmStridedBatched::BatchedGemmIndirect(const size_t m, const size_t n, const size_t b_one, const size_t b_two, const size_t c_one, const size_t c_two, const size_t batch_count) { + /* TODO // Calculates the ceiled versions of m, n, and k const auto m_ceiled = Ceil(Ceil(m, db_["MWG"]), db_["VWM"]); const auto n_ceiled = Ceil(Ceil(n, db_["NWG"]), db_["VWN"]); @@ -123,7 +124,6 @@ void XgemmStridedBatched::BatchedGemmIndirect(const size_t m, const size_t n, Xgemm::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"], a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); - /* TODO // Sets the "internal" offsets, i.e. the perfect offsets auto a_offsets_i = 0;//std::vector(batch_count); auto b_offsets_i = 0;//std::vector(batch_count); @@ -244,30 +244,33 @@ void XgemmStridedBatched::BatchedGemmDirect(const size_t m, const size_t n, c const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, const bool a_conjugate, const bool b_conjugate, const size_t batch_count) { -/* TODO + // Retrieves the proper XgemmDirect kernel from the compiled binary - const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectBatchedTT" : "XgemmDirectBatchedTN") : - (b_do_transpose ? "XgemmDirectBatchedNT" : "XgemmDirectBatchedNN"); + const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") : + (b_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN"); auto kernel = Kernel(program_, name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(m)); kernel.SetArgument(1, static_cast(n)); kernel.SetArgument(2, static_cast(k)); - kernel.SetArgument(3, alpha); - kernel.SetArgument(4, beta); + kernel.SetArgument(3, GetRealArg(alpha)); + kernel.SetArgument(4, GetRealArg(beta)); kernel.SetArgument(5, a_buffer()); - kernel.SetArgument(6, a_offset); + kernel.SetArgument(6, static_cast(a_offset)); kernel.SetArgument(7, static_cast(a_ld)); - kernel.SetArgument(8, b_buffer()); - kernel.SetArgument(9, b_offset); - kernel.SetArgument(10, static_cast(b_ld)); - kernel.SetArgument(11, c_buffer()); - kernel.SetArgument(12, c_offset); - kernel.SetArgument(13, static_cast(c_ld)); - kernel.SetArgument(14, static_cast(c_do_transpose)); - kernel.SetArgument(15, static_cast(a_conjugate)); - kernel.SetArgument(16, static_cast(b_conjugate)); + kernel.SetArgument(8, static_cast(a_stride)); + kernel.SetArgument(9, b_buffer()); + kernel.SetArgument(10, static_cast(b_offset)); + kernel.SetArgument(11, static_cast(b_ld)); + kernel.SetArgument(12, static_cast(b_stride)); + kernel.SetArgument(13, c_buffer()); + kernel.SetArgument(14, static_cast(c_offset)); + kernel.SetArgument(15, static_cast(c_ld)); + kernel.SetArgument(16, static_cast(c_stride)); + kernel.SetArgument(17, static_cast(c_do_transpose)); + kernel.SetArgument(18, static_cast(a_conjugate)); + kernel.SetArgument(19, static_cast(b_conjugate)); // Computes the global and local thread sizes const auto m_ceiled = Ceil(m, db_["WGD"]); @@ -281,7 +284,6 @@ void XgemmStridedBatched::BatchedGemmDirect(const size_t m, const size_t n, c // Launches the kernel RunKernel(kernel, queue_, device_, global, local, event_); - */ } // ================================================================================================= -- cgit v1.2.3