diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-07-16 15:18:28 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-07-16 15:18:28 +0200 |
commit | eaa348735ee5cee396f9ec629f1486ebb3dbeff7 (patch) | |
tree | bcdae90d95f259b7ec0d3d31da7520d775e3ccef /src/routines | |
parent | b33bec4a59d9d4d0b2e6a3d7e5f1d6e23d4279cb (diff) |
Created infrastructure to support a direct GEMM kernel; added a correct but slow reference kernel as a placeholder
Diffstat (limited to 'src/routines')
-rw-r--r-- | src/routines/level3/xgemm.cpp | 94 | ||||
-rw-r--r-- | src/routines/level3/xgemm.hpp | 23 |
2 files changed, 117 insertions, 0 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 0db28537..4bdf3192 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -34,6 +34,7 @@ Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): #include "../../kernels/level3/convert_hermitian.opencl" #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + #include "../../kernels/level3/xgemm_direct.opencl" ; } @@ -94,6 +95,42 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout, status = TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld); if (ErrorIn(status)) { return status; } + // Optionally runs the direct version of GEMM. TODO: Set this based on the arguments + const auto do_gemm_direct = true; // for now, for testing + if (do_gemm_direct) { + return GemmDirect(m, n, k, alpha, + a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, + c_buffer, c_offset, c_ld, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + } + else { + return GemmIndirect(m, n, k, alpha, + a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, + c_buffer, c_offset, c_ld, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + a_one, a_two, b_one, b_two, c_one, c_two); + } +} + +// ================================================================================================= + +// The indirect version of GEMM. This uses the faster but non-general kernel. It has specific +// requirements, but several pre and post-processing kernels take care of those. However, the +// overhead of these extra kernels might not be ideal for certain devices/arguments. 
+template <typename T> +StatusCode Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate, + const size_t a_one, const size_t a_two, + const size_t b_one, const size_t b_two, + const size_t c_one, const size_t c_two) { + auto status = StatusCode::kSuccess; + // Calculates the ceiled versions of m, n, and k const auto m_ceiled = Ceil(m, db_["MWG"]); const auto n_ceiled = Ceil(n, db_["NWG"]); @@ -204,6 +241,63 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout, } catch (...) { return StatusCode::kTempBufferAllocFailure; } } + +// ================================================================================================= + +// The direct version of GEMM, requiring just one kernel, no pre or post-processing kernels. 
+template <typename T> +StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate) { + + // Loads the program from the database + const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_); + + // Retrieves the XgemmDirect kernel from the compiled binary + try { + auto kernel = Kernel(program, "XgemmDirect"); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast<int>(m)); + kernel.SetArgument(1, static_cast<int>(n)); + kernel.SetArgument(2, static_cast<int>(k)); + kernel.SetArgument(3, GetRealArg(alpha)); + kernel.SetArgument(4, GetRealArg(beta)); + kernel.SetArgument(5, a_buffer()); + kernel.SetArgument(6, static_cast<int>(a_offset)); + kernel.SetArgument(7, static_cast<int>(a_ld)); + kernel.SetArgument(8, b_buffer()); + kernel.SetArgument(9, static_cast<int>(b_offset)); + kernel.SetArgument(10, static_cast<int>(b_ld)); + kernel.SetArgument(11, c_buffer()); + kernel.SetArgument(12, static_cast<int>(c_offset)); + kernel.SetArgument(13, static_cast<int>(c_ld)); + kernel.SetArgument(14, static_cast<int>(a_do_transpose)); + kernel.SetArgument(15, static_cast<int>(b_do_transpose)); + kernel.SetArgument(16, static_cast<int>(c_do_transpose)); + kernel.SetArgument(17, static_cast<int>(a_conjugate)); + kernel.SetArgument(18, static_cast<int>(b_conjugate)); + + // Computes the global and local thread sizes + const auto m_ceiled = Ceil(m, 16); + const auto n_ceiled = Ceil(n, 16); + const auto global = std::vector<size_t>{m_ceiled, n_ceiled}; + const auto local = std::vector<size_t>{16, 16}; + + // Launches the kernel + auto status = 
RunKernel(kernel, queue_, device_, global, local, event_); + if (ErrorIn(status)) { return status; } + + // Successfully finished the computation + return StatusCode::kSuccess; + } catch (...) { return StatusCode::kInvalidKernel; } +} + // ================================================================================================= // Compiles the templated class diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index bc51c7f5..8db1cb11 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -35,6 +35,29 @@ class Xgemm: public Routine { const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const T beta, const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld); + + // Indirect version of GEMM (with pre and post-processing kernels) + StatusCode GemmIndirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate, + const size_t a_one, const size_t a_two, + const size_t b_one, const size_t b_two, + const size_t c_one, const size_t c_two); + + // Direct version of GEMM (no pre and post-processing kernels) + StatusCode GemmDirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate); }; // 
================================================================================================= |