summaryrefslogtreecommitdiff
path: root/test/routines/levelx/xaxpybatched.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/routines/levelx/xaxpybatched.hpp')
-rw-r--r--test/routines/levelx/xaxpybatched.hpp19
1 files changed, 7 insertions, 12 deletions
diff --git a/test/routines/levelx/xaxpybatched.hpp b/test/routines/levelx/xaxpybatched.hpp
index 8f6a5985..ee15ff92 100644
--- a/test/routines/levelx/xaxpybatched.hpp
+++ b/test/routines/levelx/xaxpybatched.hpp
@@ -46,11 +46,6 @@ class TestXaxpyBatched {
kArgBatchCount, kArgAlpha};
}
- // Helper to determine a different alpha value per batch
- static T GetAlpha(const T alpha_base, const size_t batch_id) {
- return alpha_base + Constant<T>(batch_id);
- }
-
// Helper for the sizes per batch
static size_t PerBatchSizeX(const Arguments<T> &args) { return args.n * args.x_inc; }
static size_t PerBatchSizeY(const Arguments<T> &args) { return args.n * args.y_inc; }
@@ -67,11 +62,15 @@ class TestXaxpyBatched {
static void SetSizes(Arguments<T> &args) {
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
+
+ // Also sets the batch-related variables
args.x_offsets = std::vector<size_t>(args.batch_count);
args.y_offsets = std::vector<size_t>(args.batch_count);
+ args.alphas = std::vector<T>(args.batch_count);
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
args.x_offsets[batch] = batch * PerBatchSizeX(args) + args.x_offset;
args.y_offsets[batch] = batch * PerBatchSizeY(args) + args.y_offset;
+ args.alphas[batch] = args.alpha + Constant<T>(batch);
}
}
@@ -94,11 +93,7 @@ class TestXaxpyBatched {
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
auto queue_plain = queue();
auto event = cl_event{};
- auto alphas = std::vector<T>();
- for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
- alphas.push_back(GetAlpha(args.alpha, batch));
- }
- auto status = AxpyBatched(args.n, alphas.data(),
+ auto status = AxpyBatched(args.n, args.alphas.data(),
buffers.x_vec(), args.x_offsets.data(), args.x_inc,
buffers.y_vec(), args.y_offsets.data(), args.y_inc,
args.batch_count,
@@ -113,7 +108,7 @@ class TestXaxpyBatched {
auto queue_plain = queue();
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
auto event = cl_event{};
- auto status = clblasXaxpy(args.n, GetAlpha(args.alpha, batch),
+ auto status = clblasXaxpy(args.n, args.alphas[batch],
buffers.x_vec, args.x_offsets[batch], args.x_inc,
buffers.y_vec, args.y_offsets[batch], args.y_inc,
1, &queue_plain, 0, nullptr, &event);
@@ -134,7 +129,7 @@ class TestXaxpyBatched {
buffers.x_vec.Read(queue, args.x_size, x_vec_cpu);
buffers.y_vec.Read(queue, args.y_size, y_vec_cpu);
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
- cblasXaxpy(args.n, GetAlpha(args.alpha, batch),
+ cblasXaxpy(args.n, args.alphas[batch],
x_vec_cpu, args.x_offsets[batch], args.x_inc,
y_vec_cpu, args.y_offsets[batch], args.y_inc);
}