summary | refs | log | tree | commit | diff
path: root/src/routines
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines')
-rw-r--r-- src/routines/level3/xgemm.cpp 105
-rw-r--r-- src/routines/level3/xgemm.hpp 23
2 files changed, 127 insertions, 1 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index 0b8e768f..9d912374 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -22,7 +22,9 @@ namespace clblast {
// Constructor: forwards to base class constructor
template <typename T>
Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
- Routine(queue, event, name, {"Copy","Pad","Transpose","Padtranspose","Xgemm"}, PrecisionValue<T>()) {
+ Routine(queue, event, name,
+ {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
+ PrecisionValue<T>()) {
source_string_ =
#include "../../kernels/level3/level3.opencl"
#include "../../kernels/level3/copy_fast.opencl"
@@ -35,6 +37,9 @@ Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
#include "../../kernels/level3/xgemm_part1.opencl"
#include "../../kernels/level3/xgemm_part2.opencl"
#include "../../kernels/level3/xgemm_part3.opencl"
+ #include "../../kernels/level3/xgemm_direct_part1.opencl"
+ #include "../../kernels/level3/xgemm_direct_part2.opencl"
+ #include "../../kernels/level3/xgemm_direct_part3.opencl"
;
}
@@ -98,6 +103,44 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
status = TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld);
if (ErrorIn(status)) { return status; }
+ // Selects which version of GEMM to run
+ const auto do_gemm_direct = (m * n * k < db_["XGEMM_MIN_INDIRECT_SIZE"]);
+ if (do_gemm_direct) { // for small sizes (single kernel)
+ return GemmDirect(m, n, k, alpha,
+ a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
+ c_buffer, c_offset, c_ld,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
+ }
+ else { // for larger sizes (pre/post-processing plus a very fast kernel)
+ return GemmIndirect(m, n, k, alpha,
+ a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
+ c_buffer, c_offset, c_ld,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
+ a_one, a_two, a_want_rotated,
+ b_one, b_two, b_want_rotated,
+ c_one, c_two, c_want_rotated);
+ }
+}
+
+// =================================================================================================
+
+// The indirect version of GEMM. This uses the faster but non-general kernel. It has specific
+// requirements, but several pre- and post-processing kernels take care of those. However, the
+// overhead of these extra kernels might not be ideal for certain devices/arguments.
+template <typename T>
+StatusCode Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
+ const T beta,
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
+ const bool a_conjugate, const bool b_conjugate,
+ const size_t a_one, const size_t a_two, const bool a_want_rotated,
+ const size_t b_one, const size_t b_two, const bool b_want_rotated,
+ const size_t c_one, const size_t c_two, const bool c_want_rotated) {
+ auto status = StatusCode::kSuccess;
+
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(m, db_["MWG"]);
const auto n_ceiled = Ceil(n, db_["NWG"]);
@@ -217,6 +260,66 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
} catch (...) { return StatusCode::kTempBufferAllocFailure; }
}
+
+// =================================================================================================
+
+// The direct version of GEMM: a single XgemmDirect kernel is run on the A, B and C buffers as given, without pre- or post-processing kernels.
+template <typename T>
+StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
+ const T beta,
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
+ const bool a_conjugate, const bool b_conjugate) {
+
+ // Fetches the program from the cache, looked up with the context, the precision of T and the routine name
+ const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
+
+ // Picks the kernel by the A/B transpose flags (1st letter: A, 2nd letter: B, T = transposed); C-transpose and conjugation are kernel arguments instead
+ try {
+ const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectTT" : "XgemmDirectTN") :
+ (b_do_transpose ? "XgemmDirectNT" : "XgemmDirectNN");
+ auto kernel = Kernel(program, name);
+
+ // Sets the kernel arguments: sizes, offsets, leading dimensions and boolean flags are all cast to int; alpha and beta go through GetRealArg
+ kernel.SetArgument(0, static_cast<int>(m));
+ kernel.SetArgument(1, static_cast<int>(n));
+ kernel.SetArgument(2, static_cast<int>(k));
+ kernel.SetArgument(3, GetRealArg(alpha));
+ kernel.SetArgument(4, GetRealArg(beta));
+ kernel.SetArgument(5, a_buffer());
+ kernel.SetArgument(6, static_cast<int>(a_offset));
+ kernel.SetArgument(7, static_cast<int>(a_ld));
+ kernel.SetArgument(8, b_buffer());
+ kernel.SetArgument(9, static_cast<int>(b_offset));
+ kernel.SetArgument(10, static_cast<int>(b_ld));
+ kernel.SetArgument(11, c_buffer());
+ kernel.SetArgument(12, static_cast<int>(c_offset));
+ kernel.SetArgument(13, static_cast<int>(c_ld));
+ kernel.SetArgument(14, static_cast<int>(c_do_transpose));
+ kernel.SetArgument(15, static_cast<int>(a_conjugate));
+ kernel.SetArgument(16, static_cast<int>(b_conjugate));
+
+ // Computes the global and local thread sizes: m and n are ceiled with WGD (presumably up to a multiple of it), then scaled by MDIMCD/WGD and NDIMCD/WGD
+ const auto m_ceiled = Ceil(m, db_["WGD"]);
+ const auto n_ceiled = Ceil(n, db_["WGD"]);
+ const auto global = std::vector<size_t>{
+ (m_ceiled * db_["MDIMCD"]) / db_["WGD"],
+ (n_ceiled * db_["NDIMCD"]) / db_["WGD"]
+ };
+ const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"]};
+
+ // Launches the kernel on queue_ with event_; an error status from RunKernel is returned to the caller unchanged
+ auto status = RunKernel(kernel, queue_, device_, global, local, event_);
+ if (ErrorIn(status)) { return status; }
+
+ // The kernel launch reported no error
+ return StatusCode::kSuccess;
+ } catch (...) { return StatusCode::kInvalidKernel; } // any exception thrown inside the try-block is reported as kInvalidKernel
+}
+
// =================================================================================================
// Compiles the templated class
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index bc51c7f5..46e12453 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -35,6 +35,29 @@ class Xgemm: public Routine {
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld);
+
+ // Indirect version of GEMM (with pre- and post-processing kernels)
+ StatusCode GemmIndirect(const size_t m, const size_t n, const size_t k,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
+ const T beta,
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
+ const bool a_conjugate, const bool b_conjugate,
+ const size_t a_one, const size_t a_two, const bool a_want_rotated,
+ const size_t b_one, const size_t b_two, const bool b_want_rotated,
+ const size_t c_one, const size_t c_two, const bool c_want_rotated);
+
+ // Direct version of GEMM (no pre- or post-processing kernels)
+ StatusCode GemmDirect(const size_t m, const size_t n, const size_t k,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
+ const T beta,
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
+ const bool a_conjugate, const bool b_conjugate);
};
// =================================================================================================