diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-08 20:36:35 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-08 20:36:35 +0100 |
commit | 878d93e7dc508f53495139ab2e18c71fffdab1fd (patch) | |
tree | eb1f17e49f52126ebf2a07d21d28c10df8ab7e9a /src/routines/levelx/xaxpybatched.cpp | |
parent | fa0a9c689fc21a2a24aeadf82ae0acdf6d8bf831 (diff) |
Implemented a batched version of the AXPY kernel
Diffstat (limited to 'src/routines/levelx/xaxpybatched.cpp')
-rw-r--r-- | src/routines/levelx/xaxpybatched.cpp | 37 |
1 files changed, 17 insertions, 20 deletions
diff --git a/src/routines/levelx/xaxpybatched.cpp b/src/routines/levelx/xaxpybatched.cpp index 8089cdc6..6a4269be 100644 --- a/src/routines/levelx/xaxpybatched.cpp +++ b/src/routines/levelx/xaxpybatched.cpp @@ -57,32 +57,29 @@ void XaxpyBatched<T>::DoAxpyBatched(const size_t n, const std::vector<T> &alphas std::vector<int> y_offsets_int(y_offsets.begin(), y_offsets.end()); auto x_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count); auto y_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count); + auto alphas_device = Buffer<T>(context_, BufferAccess::kReadOnly, batch_count); x_offsets_device.Write(queue_, batch_count, x_offsets_int); y_offsets_device.Write(queue_, batch_count, y_offsets_int); + alphas_device.Write(queue_, batch_count, alphas); // Retrieves the Xaxpy kernel from the compiled binary auto kernel = Kernel(program_, "XaxpyBatched"); - // Naive implementation: calls regular Axpy multiple times - for (auto batch = size_t{0}; batch < batch_count; ++batch) { - - // Sets the kernel arguments - kernel.SetArgument(0, static_cast<int>(n)); - kernel.SetArgument(1, GetRealArg(alphas[batch])); - kernel.SetArgument(2, x_buffer()); - kernel.SetArgument(3, static_cast<int>(x_offsets[batch])); - kernel.SetArgument(4, static_cast<int>(x_inc)); - kernel.SetArgument(5, y_buffer()); - kernel.SetArgument(6, static_cast<int>(y_offsets[batch])); - kernel.SetArgument(7, static_cast<int>(y_inc)); - kernel.SetArgument(8, static_cast<int>(batch)); - - // Launches the kernel - auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]); - auto global = std::vector<size_t>{n_ceiled/db_["WPT"]}; - auto local = std::vector<size_t>{db_["WGS"]}; - RunKernel(kernel, queue_, device_, global, local, event_); - } + // Sets the kernel arguments + kernel.SetArgument(0, static_cast<int>(n)); + kernel.SetArgument(1, alphas_device()); + kernel.SetArgument(2, x_buffer()); + kernel.SetArgument(3, x_offsets_device()); + kernel.SetArgument(4, static_cast<int>(x_inc)); + kernel.SetArgument(5, y_buffer()); + kernel.SetArgument(6, y_offsets_device()); + kernel.SetArgument(7, static_cast<int>(y_inc)); + + // Launches the kernel + auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]); + auto global = std::vector<size_t>{n_ceiled/db_["WPT"], batch_count}; + auto local = std::vector<size_t>{db_["WGS"], 1}; + RunKernel(kernel, queue_, device_, global, local, event_); } // ================================================================================================= |