diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/kernels/common.opencl | 2 | ||||
-rw-r--r-- | src/kernels/level3/xgemm_direct.opencl | 71 | ||||
-rw-r--r-- | src/routines/level3/xgemm.cpp | 94 | ||||
-rw-r--r-- | src/routines/level3/xgemm.hpp | 23 |
4 files changed, 189 insertions, 1 deletions
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl index 9d2bb65e..2fca6b73 100644 --- a/src/kernels/common.opencl +++ b/src/kernels/common.opencl @@ -197,7 +197,7 @@ R"( #if PRECISION == 3232 || PRECISION == 6464 #define COMPLEX_CONJUGATE(value) value.x = value.x; value.y = -value.y #else - #define COMPLEX_CONJUGATE(value) value = value + #define COMPLEX_CONJUGATE(value) #endif // ================================================================================================= diff --git a/src/kernels/level3/xgemm_direct.opencl b/src/kernels/level3/xgemm_direct.opencl new file mode 100644 index 00000000..9d2a55c8 --- /dev/null +++ b/src/kernels/level3/xgemm_direct.opencl @@ -0,0 +1,71 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This is a generic GEMM kernel that works for all sizes and configurations: it doesn't require any +// pre and and post-processing kernels. +// +// ================================================================================================= + +// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string +// literal). Comment-out this line for syntax-highlighting when developing. +R"( + +// ================================================================================================= + +// Main entry point of the kernel. This is the direct version. +__attribute__((reqd_work_group_size(16, 16, 1))) +__kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK, + const real_arg arg_alpha, + const real_arg arg_beta, + const __global real* restrict agm, const int a_offset, const int a_ld, + const __global real* restrict bgm, const int b_offset, const int b_ld, + __global real* cgm, const int c_offset, const int c_ld, + const int a_transpose, const int b_transpose, const int c_transpose, + const int a_conjugate, const int b_conjugate) { + const real alpha = GetRealArg(arg_alpha); + const real beta = GetRealArg(arg_beta); + + // Thread identifiers + const int mid = get_global_id(0); // Row ID of cgm + const int nid = get_global_id(1); // Col ID of cgm + + // Allows for incomplete workgroups + if (mid < kSizeM && nid < kSizeN) { + + // Computes a single element + real acc; + SetToZero(acc); + for (int k=0; k<kSizeK; ++k) { + const int a_index = (a_transpose) ? mid*a_ld + k : k*a_ld + mid; + const int b_index = (b_transpose) ? nid*b_ld + k : k*b_ld + nid; + real a_val = agm[a_index + a_offset]; + real b_val = bgm[b_index + b_offset]; + if (a_conjugate) { COMPLEX_CONJUGATE(a_val); } + if (b_conjugate) { COMPLEX_CONJUGATE(b_val); } + MultiplyAdd(acc, a_val, b_val); + } + + // Determines the destination index + const int c_index = (c_transpose) ? mid*c_ld + nid : nid*c_ld + mid; + + // Computes the result + real result; + AXPBY(result, alpha, acc, beta, cgm[c_index + c_offset]); + + // Stores the result + cgm[c_index + c_offset] = result; + } +} + +// ================================================================================================= + +// End of the C++11 raw string literal +)" + +// ================================================================================================= diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 0db28537..4bdf3192 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -34,6 +34,7 @@ Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): #include "../../kernels/level3/convert_hermitian.opencl" #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + #include "../../kernels/level3/xgemm_direct.opencl" ; } @@ -94,6 +95,42 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout, status = TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld); if (ErrorIn(status)) { return status; } + // Optionally runs the direct version of GEMM. TODO: Set this based on the arguments + const auto do_gemm_direct = true; // for now, for testing + if (do_gemm_direct) { + return GemmDirect(m, n, k, alpha, + a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, + c_buffer, c_offset, c_ld, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + } + else { + return GemmIndirect(m, n, k, alpha, + a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, + c_buffer, c_offset, c_ld, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, + a_one, a_two, b_one, b_two, c_one, c_two); + } +} + +// ================================================================================================= + +// The indirect version of GEMM. This uses the faster but non-general kernel. It has specific +// requirements, but several pre and post-processing kernels take care of those. However, the +// overhead of these extra kernels might not be ideal for certain devices/arguments. +template <typename T> +StatusCode Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate, + const size_t a_one, const size_t a_two, + const size_t b_one, const size_t b_two, + const size_t c_one, const size_t c_two) { + auto status = StatusCode::kSuccess; + // Calculates the ceiled versions of m, n, and k const auto m_ceiled = Ceil(m, db_["MWG"]); const auto n_ceiled = Ceil(n, db_["NWG"]); @@ -204,6 +241,63 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout, } catch (...) { return StatusCode::kTempBufferAllocFailure; } } + +// ================================================================================================= + +// The direct version of GEMM, requiring just one kernel, no pre or post-processing kernels. +template <typename T> +StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate) { + + // Loads the program from the database + const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_); + + // Retrieves the XgemmDirect kernel from the compiled binary + try { + auto kernel = Kernel(program, "XgemmDirect"); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast<int>(m)); + kernel.SetArgument(1, static_cast<int>(n)); + kernel.SetArgument(2, static_cast<int>(k)); + kernel.SetArgument(3, GetRealArg(alpha)); + kernel.SetArgument(4, GetRealArg(beta)); + kernel.SetArgument(5, a_buffer()); + kernel.SetArgument(6, static_cast<int>(a_offset)); + kernel.SetArgument(7, static_cast<int>(a_ld)); + kernel.SetArgument(8, b_buffer()); + kernel.SetArgument(9, static_cast<int>(b_offset)); + kernel.SetArgument(10, static_cast<int>(b_ld)); + kernel.SetArgument(11, c_buffer()); + kernel.SetArgument(12, static_cast<int>(c_offset)); + kernel.SetArgument(13, static_cast<int>(c_ld)); + kernel.SetArgument(14, static_cast<int>(a_do_transpose)); + kernel.SetArgument(15, static_cast<int>(b_do_transpose)); + kernel.SetArgument(16, static_cast<int>(c_do_transpose)); + kernel.SetArgument(17, static_cast<int>(a_conjugate)); + kernel.SetArgument(18, static_cast<int>(b_conjugate)); + + // Computes the global and local thread sizes + const auto m_ceiled = Ceil(m, 16); + const auto n_ceiled = Ceil(n, 16); + const auto global = std::vector<size_t>{m_ceiled, n_ceiled}; + const auto local = std::vector<size_t>{16, 16}; + + // Launches the kernel + auto status = RunKernel(kernel, queue_, device_, global, local, event_); + if (ErrorIn(status)) { return status; } + + // Successfully finished the computation + return StatusCode::kSuccess; + } catch (...) { return StatusCode::kInvalidKernel; } +} + // ================================================================================================= // Compiles the templated class diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index bc51c7f5..8db1cb11 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -35,6 +35,29 @@ class Xgemm: public Routine { const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const T beta, const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld); + + // Indirect version of GEMM (with pre and post-processing kernels) + StatusCode GemmIndirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate, + const size_t a_one, const size_t a_two, + const size_t b_one, const size_t b_two, + const size_t c_one, const size_t c_two); + + // Direct version of GEMM (no pre and post-processing kernels) + StatusCode GemmDirect(const size_t m, const size_t n, const size_t k, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, + const bool a_conjugate, const bool b_conjugate); }; // ================================================================================================= |