summaryrefslogtreecommitdiff
path: root/src/utilities/utilities.hpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-03-10 21:15:29 +0100
committerGitHub <noreply@github.com>2017-03-10 21:15:29 +0100
commitde3500ed18ddb39261ffa270f460909571276462 (patch)
treeb515368fcd1e39afb5805f67796b082ccc8066f9 /src/utilities/utilities.hpp
parent37228c90988509acef9e8a892a752300b7645210 (diff)
parent3846f44eaf389ee24a698d4947e5c16bd14c3d0e (diff)
Merge pull request #141 from CNugteren/axpy_batched
Added the batched version of the AXPY routine
Diffstat (limited to 'src/utilities/utilities.hpp')
-rw-r--r--src/utilities/utilities.hpp14
1 files changed, 13 insertions, 1 deletions
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index 3c9be6a2..b3db8c22 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -20,6 +20,7 @@
#include <string>
#include <functional>
#include <complex>
+#include <random>
#include "clpp11.hpp"
#include "clblast.h"
@@ -72,6 +73,7 @@ constexpr auto kArgAsumOffset = "offasum";
constexpr auto kArgImaxOffset = "offimax";
constexpr auto kArgAlpha = "alpha";
constexpr auto kArgBeta = "beta";
+constexpr auto kArgBatchCount = "batch_num";
// The tuner-specific arguments in string form
constexpr auto kArgFraction = "fraction";
@@ -155,6 +157,16 @@ struct Arguments {
size_t imax_offset = 0;
T alpha = ConstantOne<T>();
T beta = ConstantOne<T>();
+ // Batch-specific arguments
+ size_t batch_count = 1;
+ std::vector<size_t> x_offsets = {0};
+ std::vector<size_t> y_offsets = {0};
+ std::vector<size_t> a_offsets = {0};
+ std::vector<size_t> b_offsets = {0};
+ std::vector<size_t> c_offsets = {0};
+ std::vector<T> alphas = {ConstantOne<T>()};
+ std::vector<T> betas = {ConstantOne<T>()};
+ // Sizes
size_t x_size = 1;
size_t y_size = 1;
size_t a_size = 1;
@@ -234,7 +246,7 @@ constexpr auto kTestDataUpperLimit = 2.0;
// Populates a vector with random data
template <typename T>
-void PopulateVector(std::vector<T> &vector, const unsigned int seed);
+void PopulateVector(std::vector<T> &vector, std::mt19937 &mt, std::uniform_real_distribution<double> &dist);
// =================================================================================================