diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/kernels/level1/xaxpy.opencl | 14 | ||||
-rw-r--r-- | src/routines/levelx/xaxpybatched.cpp | 37 |
2 files changed, 24 insertions, 27 deletions
diff --git a/src/kernels/level1/xaxpy.opencl b/src/kernels/level1/xaxpy.opencl index 0d730c9e..3f5ab2b5 100644 --- a/src/kernels/level1/xaxpy.opencl +++ b/src/kernels/level1/xaxpy.opencl @@ -57,17 +57,17 @@ void XaxpyFast(const int n, const real_arg arg_alpha, // Full version of the kernel with offsets and strided accesses: batched version __kernel __attribute__((reqd_work_group_size(WGS, 1, 1))) -void XaxpyBatched(const int n, const real_arg arg_alpha, - const __global real* restrict xgm, const int x_offset, const int x_inc, - __global real* ygm, const int y_offset, const int y_inc, - const int batch) { - const real alpha = GetRealArg(arg_alpha); +void XaxpyBatched(const int n, const __global real_arg* arg_alphas, + const __global real* restrict xgm, const __global int* restrict x_offsets, const int x_inc, + __global real* ygm, const __global int* restrict y_offsets, const int y_inc) { + const int batch = get_group_id(1); + const real alpha = GetRealArg(arg_alphas[batch]); // Loops over the work that needs to be done (allows for an arbitrary number of threads) #pragma unroll for (int id = get_global_id(0); id<n; id += get_global_size(0)) { - real xvalue = xgm[id*x_inc + x_offset]; - MultiplyAdd(ygm[id*y_inc + y_offset], alpha, xvalue); + real xvalue = xgm[id*x_inc + x_offsets[batch]]; + MultiplyAdd(ygm[id*y_inc + y_offsets[batch]], alpha, xvalue); } } diff --git a/src/routines/levelx/xaxpybatched.cpp b/src/routines/levelx/xaxpybatched.cpp index 8089cdc6..6a4269be 100644 --- a/src/routines/levelx/xaxpybatched.cpp +++ b/src/routines/levelx/xaxpybatched.cpp @@ -57,32 +57,29 @@ void XaxpyBatched<T>::DoAxpyBatched(const size_t n, const std::vector<T> &alphas std::vector<int> y_offsets_int(y_offsets.begin(), y_offsets.end()); auto x_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count); auto y_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count); + auto alphas_device = Buffer<T>(context_, BufferAccess::kReadOnly, batch_count); x_offsets_device.Write(queue_, batch_count, x_offsets_int); y_offsets_device.Write(queue_, batch_count, y_offsets_int); + alphas_device.Write(queue_, batch_count, alphas); // Retrieves the Xaxpy kernel from the compiled binary auto kernel = Kernel(program_, "XaxpyBatched"); - // Naive implementation: calls regular Axpy multiple times - for (auto batch = size_t{0}; batch < batch_count; ++batch) { - - // Sets the kernel arguments - kernel.SetArgument(0, static_cast<int>(n)); - kernel.SetArgument(1, GetRealArg(alphas[batch])); - kernel.SetArgument(2, x_buffer()); - kernel.SetArgument(3, static_cast<int>(x_offsets[batch])); - kernel.SetArgument(4, static_cast<int>(x_inc)); - kernel.SetArgument(5, y_buffer()); - kernel.SetArgument(6, static_cast<int>(y_offsets[batch])); - kernel.SetArgument(7, static_cast<int>(y_inc)); - kernel.SetArgument(8, static_cast<int>(batch)); - - // Launches the kernel - auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]); - auto global = std::vector<size_t>{n_ceiled/db_["WPT"]}; - auto local = std::vector<size_t>{db_["WGS"]}; - RunKernel(kernel, queue_, device_, global, local, event_); - } + // Sets the kernel arguments + kernel.SetArgument(0, static_cast<int>(n)); + kernel.SetArgument(1, alphas_device()); + kernel.SetArgument(2, x_buffer()); + kernel.SetArgument(3, x_offsets_device()); + kernel.SetArgument(4, static_cast<int>(x_inc)); + kernel.SetArgument(5, y_buffer()); + kernel.SetArgument(6, y_offsets_device()); + kernel.SetArgument(7, static_cast<int>(y_inc)); + + // Launches the kernel + auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]); + auto global = std::vector<size_t>{n_ceiled/db_["WPT"], batch_count}; + auto local = std::vector<size_t>{db_["WGS"], 1}; + RunKernel(kernel, queue_, device_, global, local, event_); } // ================================================================================================= |