summary | refs | log | tree | commit | diff
path: root/src/routines
diff options
context:
space:
mode:
author: CNugteren <web@cedricnugteren.nl> 2015-06-14 11:15:53 +0200
committer: CNugteren <web@cedricnugteren.nl> 2015-06-14 11:15:53 +0200
commit: 294a3e3d410c87ffcc7fc550e09b6d45c71a0af8 (patch)
tree: d68a45bb8312aabba9589bb1c51b2c6ffe0dc504 /src/routines
parent: ab0064dab76c83ee9820acb62fa914c493c2563d (diff)
Split the three variations of the GEMV kernel for maximal tuning freedom
Diffstat (limited to 'src/routines')
-rw-r--r-- src/routines/xgemv.cc 36
1 file changed, 26 insertions, 10 deletions
diff --git a/src/routines/xgemv.cc b/src/routines/xgemv.cc
index 74851ec9..9f3908f8 100644
--- a/src/routines/xgemv.cc
+++ b/src/routines/xgemv.cc
@@ -70,13 +70,30 @@ StatusCode Xgemv<T>::DoGemv(const Layout layout, const Transpose a_transpose,
if (ErrorIn(status)) { return status; }
// Determines whether or not the fast-version can be used
- bool use_fast_kernel = (a_offset == 0) &&
- IsMultiple(m, db_["WGS"]*db_["WPT"]) &&
- IsMultiple(n, db_["WGS"]) &&
- IsMultiple(a_ld, db_["VW"]);
-
- // If possible, run the fast-version of the kernel
- auto kernel_name = (use_fast_kernel) ? "XgemvFast" : "Xgemv";
+ bool use_fast_kernel = (a_offset == 0) && (a_rotated == 0) &&
+ IsMultiple(m, db_["WGS2"]*db_["WPT2"]) &&
+ IsMultiple(n, db_["WGS2"]) &&
+ IsMultiple(a_ld, db_["VW2"]);
+ bool use_fast_kernel_rot = (a_offset == 0) && (a_rotated == 1) &&
+ IsMultiple(m, db_["WGS3"]*db_["WPT3"]) &&
+ IsMultiple(n, db_["WGS3"]) &&
+ IsMultiple(a_ld, db_["VW3"]);
+
+ // If possible, run the fast-version (rotated or non-rotated) of the kernel
+ auto kernel_name = "Xgemv";
+ auto m_ceiled = Ceil(m_real, db_["WGS1"]*db_["WPT1"]);
+ auto global_size = m_ceiled / db_["WPT1"];
+ auto local_size = db_["WGS1"];
+ if (use_fast_kernel) {
+ kernel_name = "XgemvFast";
+ global_size = m_real / db_["WPT2"];
+ local_size = db_["WGS2"];
+ }
+ if (use_fast_kernel_rot) {
+ kernel_name = "XgemvFastRot";
+ global_size = m_real / db_["WPT3"];
+ local_size = db_["WGS3"];
+ }
// Retrieves the Xgemv kernel from the compiled binary
try {
@@ -100,9 +117,8 @@ StatusCode Xgemv<T>::DoGemv(const Layout layout, const Transpose a_transpose,
kernel.SetArgument(13, static_cast<int>(y_inc));
// Launches the kernel
- auto m_ceiled = Ceil(m_real, db_["WGS"]*db_["WPT"]);
- auto global = std::vector<size_t>{m_ceiled / db_["WPT"]};
- auto local = std::vector<size_t>{db_["WGS"]};
+ auto global = std::vector<size_t>{global_size};
+ auto local = std::vector<size_t>{local_size};
status = RunKernel(kernel, global, local);
if (ErrorIn(status)) { return status; }