summaryrefslogtreecommitdiff
path: root/src/routines/levelx/xaxpybatched.cpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-03-08 20:36:35 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-03-08 20:36:35 +0100
commit878d93e7dc508f53495139ab2e18c71fffdab1fd (patch)
treeeb1f17e49f52126ebf2a07d21d28c10df8ab7e9a /src/routines/levelx/xaxpybatched.cpp
parentfa0a9c689fc21a2a24aeadf82ae0acdf6d8bf831 (diff)
Implemented a batched version of the AXPY kernel
Diffstat (limited to 'src/routines/levelx/xaxpybatched.cpp')
-rw-r--r--src/routines/levelx/xaxpybatched.cpp37
1 files changed, 17 insertions, 20 deletions
diff --git a/src/routines/levelx/xaxpybatched.cpp b/src/routines/levelx/xaxpybatched.cpp
index 8089cdc6..6a4269be 100644
--- a/src/routines/levelx/xaxpybatched.cpp
+++ b/src/routines/levelx/xaxpybatched.cpp
@@ -57,32 +57,29 @@ void XaxpyBatched<T>::DoAxpyBatched(const size_t n, const std::vector<T> &alphas
std::vector<int> y_offsets_int(y_offsets.begin(), y_offsets.end());
auto x_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
auto y_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
+ auto alphas_device = Buffer<T>(context_, BufferAccess::kReadOnly, batch_count);
x_offsets_device.Write(queue_, batch_count, x_offsets_int);
y_offsets_device.Write(queue_, batch_count, y_offsets_int);
+ alphas_device.Write(queue_, batch_count, alphas);
// Retrieves the Xaxpy kernel from the compiled binary
auto kernel = Kernel(program_, "XaxpyBatched");
- // Naive implementation: calls regular Axpy multiple times
- for (auto batch = size_t{0}; batch < batch_count; ++batch) {
-
- // Sets the kernel arguments
- kernel.SetArgument(0, static_cast<int>(n));
- kernel.SetArgument(1, GetRealArg(alphas[batch]));
- kernel.SetArgument(2, x_buffer());
- kernel.SetArgument(3, static_cast<int>(x_offsets[batch]));
- kernel.SetArgument(4, static_cast<int>(x_inc));
- kernel.SetArgument(5, y_buffer());
- kernel.SetArgument(6, static_cast<int>(y_offsets[batch]));
- kernel.SetArgument(7, static_cast<int>(y_inc));
- kernel.SetArgument(8, static_cast<int>(batch));
-
- // Launches the kernel
- auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]);
- auto global = std::vector<size_t>{n_ceiled/db_["WPT"]};
- auto local = std::vector<size_t>{db_["WGS"]};
- RunKernel(kernel, queue_, device_, global, local, event_);
- }
+ // Sets the kernel arguments
+ kernel.SetArgument(0, static_cast<int>(n));
+ kernel.SetArgument(1, alphas_device());
+ kernel.SetArgument(2, x_buffer());
+ kernel.SetArgument(3, x_offsets_device());
+ kernel.SetArgument(4, static_cast<int>(x_inc));
+ kernel.SetArgument(5, y_buffer());
+ kernel.SetArgument(6, y_offsets_device());
+ kernel.SetArgument(7, static_cast<int>(y_inc));
+
+ // Launches the kernel
+ auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]);
+ auto global = std::vector<size_t>{n_ceiled/db_["WPT"], batch_count};
+ auto local = std::vector<size_t>{db_["WGS"], 1};
+ RunKernel(kernel, queue_, device_, global, local, event_);
}
// =================================================================================================