summaryrefslogtreecommitdiff
path: root/src/clblast.cpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-03-08 20:10:20 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-03-08 20:10:20 +0100
commitfa0a9c689fc21a2a24aeadf82ae0acdf6d8bf831 (patch)
tree404e85900a4c9038d407addb38798d06bb48868c /src/clblast.cpp
parent6aba0bbae71702c4eebd88d0fe17739b509185c1 (diff)
Make batched routines based on offsets instead of a vector of cl_mem objects - undoing many earlier changes
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r--src/clblast.cpp36
1 files changed, 18 insertions, 18 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index e9cac664..d3db8edf 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2178,57 +2178,57 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
template <typename T>
StatusCode AxpyBatched(const size_t n,
const T *alphas,
- const cl_mem *x_buffers, const size_t x_inc,
- cl_mem *y_buffers, const size_t y_inc,
+ const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
+ cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = XaxpyBatched<T>(queue_cpp, event);
auto alphas_cpp = std::vector<T>();
- auto x_buffers_cpp = std::vector<Buffer<T>>();
- auto y_buffers_cpp = std::vector<Buffer<T>>();
+ auto x_offsets_cpp = std::vector<size_t>();
+ auto y_offsets_cpp = std::vector<size_t>();
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
alphas_cpp.push_back(alphas[batch]);
- x_buffers_cpp.push_back(Buffer<T>(x_buffers[batch]));
- y_buffers_cpp.push_back(Buffer<T>(y_buffers[batch]));
+ x_offsets_cpp.push_back(x_offsets[batch]);
+ y_offsets_cpp.push_back(y_offsets[batch]);
}
routine.DoAxpyBatched(n,
alphas_cpp,
- x_buffers_cpp, x_inc,
- y_buffers_cpp, y_inc,
+ Buffer<T>(x_buffer), x_offsets_cpp, x_inc,
+ Buffer<T>(y_buffer), y_offsets_cpp, y_inc,
batch_count);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API AxpyBatched<float>(const size_t,
const float*,
- const cl_mem*, const size_t,
- cl_mem*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API AxpyBatched<double>(const size_t,
const double*,
- const cl_mem*, const size_t,
- cl_mem*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API AxpyBatched<float2>(const size_t,
const float2*,
- const cl_mem*, const size_t,
- cl_mem*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API AxpyBatched<double2>(const size_t,
const double2*,
- const cl_mem*, const size_t,
- cl_mem*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API AxpyBatched<half>(const size_t,
const half*,
- const cl_mem*, const size_t,
- cl_mem*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
// =================================================================================================