diff options
Diffstat (limited to 'src/routines/level1/xaxpy.cpp')
-rw-r--r-- | src/routines/level1/xaxpy.cpp | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/src/routines/level1/xaxpy.cpp b/src/routines/level1/xaxpy.cpp index 310562a0..0e588d99 100644 --- a/src/routines/level1/xaxpy.cpp +++ b/src/routines/level1/xaxpy.cpp @@ -44,18 +44,21 @@ void Xaxpy<T>::DoAxpy(const size_t n, const T alpha, TestVectorY(n, y_buffer, y_offset, y_inc); // Determines whether or not the fast-version can be used - const auto use_fast_kernel = (x_offset == 0) && (x_inc == 1) && - (y_offset == 0) && (y_inc == 1) && - IsMultiple(n, db_["WGS"]*db_["WPT"]*db_["VW"]); + const auto use_faster_kernel = (x_offset == 0) && (x_inc == 1) && + (y_offset == 0) && (y_inc == 1) && + IsMultiple(n, db_["WPT"]*db_["VW"]); + const auto use_fastest_kernel = use_faster_kernel && + IsMultiple(n, db_["WGS"]*db_["WPT"]*db_["VW"]); // If possible, run the fast-version of the kernel - const auto kernel_name = (use_fast_kernel) ? "XaxpyFast" : "Xaxpy"; + const auto kernel_name = (use_fastest_kernel) ? "XaxpyFastest" : + (use_faster_kernel) ? "XaxpyFaster" : "Xaxpy"; // Retrieves the Xaxpy kernel from the compiled binary auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments - if (use_fast_kernel) { + if (use_faster_kernel || use_fastest_kernel) { kernel.SetArgument(0, static_cast<int>(n)); kernel.SetArgument(1, GetRealArg(alpha)); kernel.SetArgument(2, x_buffer()); @@ -73,11 +76,16 @@ void Xaxpy<T>::DoAxpy(const size_t n, const T alpha, } // Launches the kernel - if (use_fast_kernel) { + if (use_fastest_kernel) { auto global = std::vector<size_t>{CeilDiv(n, db_["WPT"]*db_["VW"])}; auto local = std::vector<size_t>{db_["WGS"]}; RunKernel(kernel, queue_, device_, global, local, event_); } + else if (use_faster_kernel) { + auto global = std::vector<size_t>{Ceil(CeilDiv(n, db_["WPT"]*db_["VW"]), db_["WGS"])}; + auto local = std::vector<size_t>{db_["WGS"]}; + RunKernel(kernel, queue_, device_, global, local, event_); + } else { const auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]); auto global = std::vector<size_t>{n_ceiled/db_["WPT"]}; |