summary | refs | log | tree | commit | diff
path: root/src/routines/level1/xaxpy.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines/level1/xaxpy.cpp')
-rw-r--r-- src/routines/level1/xaxpy.cpp 20
1 file changed, 14 insertions, 6 deletions
diff --git a/src/routines/level1/xaxpy.cpp b/src/routines/level1/xaxpy.cpp
index 310562a0..0e588d99 100644
--- a/src/routines/level1/xaxpy.cpp
+++ b/src/routines/level1/xaxpy.cpp
@@ -44,18 +44,21 @@ void Xaxpy<T>::DoAxpy(const size_t n, const T alpha,
TestVectorY(n, y_buffer, y_offset, y_inc);
// Determines whether or not the fast-version can be used
- const auto use_fast_kernel = (x_offset == 0) && (x_inc == 1) &&
- (y_offset == 0) && (y_inc == 1) &&
- IsMultiple(n, db_["WGS"]*db_["WPT"]*db_["VW"]);
+ const auto use_faster_kernel = (x_offset == 0) && (x_inc == 1) &&
+ (y_offset == 0) && (y_inc == 1) &&
+ IsMultiple(n, db_["WPT"]*db_["VW"]);
+ const auto use_fastest_kernel = use_faster_kernel &&
+ IsMultiple(n, db_["WGS"]*db_["WPT"]*db_["VW"]);
// If possible, run the fast-version of the kernel
- const auto kernel_name = (use_fast_kernel) ? "XaxpyFast" : "Xaxpy";
+ const auto kernel_name = (use_fastest_kernel) ? "XaxpyFastest" :
+ (use_faster_kernel) ? "XaxpyFaster" : "Xaxpy";
// Retrieves the Xaxpy kernel from the compiled binary
auto kernel = Kernel(program_, kernel_name);
// Sets the kernel arguments
- if (use_fast_kernel) {
+ if (use_faster_kernel || use_fastest_kernel) {
kernel.SetArgument(0, static_cast<int>(n));
kernel.SetArgument(1, GetRealArg(alpha));
kernel.SetArgument(2, x_buffer());
@@ -73,11 +76,16 @@ void Xaxpy<T>::DoAxpy(const size_t n, const T alpha,
}
// Launches the kernel
- if (use_fast_kernel) {
+ if (use_fastest_kernel) {
auto global = std::vector<size_t>{CeilDiv(n, db_["WPT"]*db_["VW"])};
auto local = std::vector<size_t>{db_["WGS"]};
RunKernel(kernel, queue_, device_, global, local, event_);
}
+ else if (use_faster_kernel) {
+ auto global = std::vector<size_t>{Ceil(CeilDiv(n, db_["WPT"]*db_["VW"]), db_["WGS"])};
+ auto local = std::vector<size_t>{db_["WGS"]};
+ RunKernel(kernel, queue_, device_, global, local, event_);
+ }
else {
const auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]);
auto global = std::vector<size_t>{n_ceiled/db_["WPT"]};