diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-04-01 13:36:24 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-04-01 13:36:24 +0200 |
commit | b84d2296b87ac212474af855d916b12adf96bdb7 (patch) | |
tree | 0f2e85e1e1acef1d22f046499dd0b8a30e5da4f9 /test/performance/client.hpp | |
parent | a98c00a2671b8981579f3a73dca8fb3365a95e53 (diff) |
Separated host-device and device-host memory copies from execution of the CBLAS reference code; for fair timing and code de-duplication
Diffstat (limited to 'test/performance/client.hpp')
-rw-r--r-- | test/performance/client.hpp | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/test/performance/client.hpp b/test/performance/client.hpp index b5cc1465..12fd113d 100644 --- a/test/performance/client.hpp +++ b/test/performance/client.hpp @@ -44,12 +44,15 @@ class Client { // Shorthand for the routine-specific functions passed to the tester using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>; + using Reference1 = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>; + using Reference2 = std::function<StatusCode(const Arguments<U>&, BuffersHost<T>&, Queue&)>; using SetMetric = std::function<void(Arguments<U>&)>; using GetMetric = std::function<size_t(const Arguments<U>&)>; // The constructor - Client(const Routine run_routine, const Routine run_reference1, const Routine run_reference2, + Client(const Routine run_routine, const Reference1 run_reference1, const Reference2 run_reference2, const std::vector<std::string> &options, + const std::vector<std::string> &buffers_in, const std::vector<std::string> &buffers_out, const GetMetric get_flops, const GetMetric get_bytes); // Parses all command-line arguments, filling in the arguments structure. If no command-line @@ -66,8 +69,9 @@ class Client { private: // Runs a function a given number of times and returns the execution time of the shortest instance - double TimedExecution(const size_t num_runs, const Arguments<U> &args, Buffers<T> &buffers, - Queue &queue, Routine run_blas, const std::string &library_name); + template <typename BufferType, typename RoutineType> + double TimedExecution(const size_t num_runs, const Arguments<U> &args, BufferType &buffers, + Queue &queue, RoutineType run_blas, const std::string &library_name); // Prints the header of a performance-data table void PrintTableHeader(const Arguments<U>& args); @@ -78,9 +82,11 @@ class Client { // The routine-specific functions passed to the tester const Routine run_routine_; - const Routine run_reference1_; - const Routine run_reference2_; + const Reference1 run_reference1_; + const Reference2 run_reference2_; const std::vector<std::string> options_; + const std::vector<std::string> buffers_in_; + const std::vector<std::string> buffers_out_; const GetMetric get_flops_; const GetMetric get_bytes_; @@ -91,8 +97,8 @@ class Client { // ================================================================================================= // Bogus reference function, in case a comparison library is not available -template <typename T, typename U> -static StatusCode ReferenceNotAvailable(const Arguments<U> &, Buffers<T> &, Queue &) { +template <typename T, typename U, typename BufferType> +static StatusCode ReferenceNotAvailable(const Arguments<U> &, BufferType &, Queue &) { return StatusCode::kNotImplemented; } @@ -105,17 +111,17 @@ void RunClient(int argc, char *argv[]) { #ifdef CLBLAST_REF_CLBLAS auto reference1 = C::RunReference1; // clBLAS when available #else - auto reference1 = ReferenceNotAvailable<T,U>; + auto reference1 = ReferenceNotAvailable<T,U,Buffers<T>>; #endif #ifdef CLBLAST_REF_CBLAS auto reference2 = C::RunReference2; // CBLAS when available #else - auto reference2 = ReferenceNotAvailable<T,U>; + auto reference2 = ReferenceNotAvailable<T,U,BuffersHost<T>>; #endif // Creates a new client auto client = Client<T,U>(C::RunRoutine, reference1, reference2, C::GetOptions(), - C::GetFlops, C::GetBytes); + C::BuffersIn(), C::BuffersOut(), C::GetFlops, C::GetBytes); // Simple command line argument parser with defaults auto args = client.ParseArguments(argc, argv, C::BLASLevel(), |