summaryrefslogtreecommitdiff
path: root/test/performance/client.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/performance/client.hpp')
-rw-r--r--test/performance/client.hpp12
1 files changed, 10 insertions, 2 deletions
diff --git a/test/performance/client.hpp b/test/performance/client.hpp
index 12fd113d..47a13017 100644
--- a/test/performance/client.hpp
+++ b/test/performance/client.hpp
@@ -31,6 +31,7 @@
#ifdef CLBLAST_REF_CLBLAS
#include <clBLAS.h>
#endif
+#include "test/wrapper_cuda.hpp"
#include "clblast.h"
namespace clblast {
@@ -46,12 +47,13 @@ class Client {
using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
using Reference1 = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
using Reference2 = std::function<StatusCode(const Arguments<U>&, BuffersHost<T>&, Queue&)>;
+ using Reference3 = std::function<StatusCode(const Arguments<U>&, BuffersCUDA<T>&, Queue&)>;
using SetMetric = std::function<void(Arguments<U>&)>;
using GetMetric = std::function<size_t(const Arguments<U>&)>;
// The constructor
Client(const Routine run_routine, const Reference1 run_reference1, const Reference2 run_reference2,
- const std::vector<std::string> &options,
+ const Reference3 run_reference3, const std::vector<std::string> &options,
const std::vector<std::string> &buffers_in, const std::vector<std::string> &buffers_out,
const GetMetric get_flops, const GetMetric get_bytes);
@@ -84,6 +86,7 @@ class Client {
const Routine run_routine_;
const Reference1 run_reference1_;
const Reference2 run_reference2_;
+ const Reference3 run_reference3_;
const std::vector<std::string> options_;
const std::vector<std::string> buffers_in_;
const std::vector<std::string> buffers_out_;
@@ -118,9 +121,14 @@ void RunClient(int argc, char *argv[]) {
#else
auto reference2 = ReferenceNotAvailable<T,U,BuffersHost<T>>;
#endif
+ #ifdef CLBLAST_REF_CUBLAS
+ auto reference3 = C::RunReference3; // cuBLAS when available
+ #else
+ auto reference3 = ReferenceNotAvailable<T,U,BuffersCUDA<T>>;
+ #endif
// Creates a new client
- auto client = Client<T,U>(C::RunRoutine, reference1, reference2, C::GetOptions(),
+ auto client = Client<T,U>(C::RunRoutine, reference1, reference2, reference3, C::GetOptions(),
C::BuffersIn(), C::BuffersOut(), C::GetFlops, C::GetBytes);
// Simple command line argument parser with defaults