summaryrefslogtreecommitdiff
path: root/test/performance/client.h
diff options
context:
space:
mode:
Diffstat (limited to 'test/performance/client.h')
-rw-r--r--test/performance/client.h104
1 files changed, 58 insertions, 46 deletions
diff --git a/test/performance/client.h b/test/performance/client.h
index 097ae048..f9f219d0 100644
--- a/test/performance/client.h
+++ b/test/performance/client.h
@@ -7,7 +7,12 @@
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
-// This file provides common function declarations to be used with the test clients.
+// This class implements the performance-test client. It is generic for all CLBlast routines by
+// taking a number of routine-specific functions as arguments, such as how to compute buffer sizes
+// or how to get the FLOPS count.
+//
+// This file also provides the common interface to the performance client (see the 'RunClient'
+// function for details).
//
// =================================================================================================
@@ -26,64 +31,71 @@
namespace clblast {
// =================================================================================================
-// Types of devices to consider
-const cl_device_type kDeviceType = CL_DEVICE_TYPE_ALL;
+// See comment at top of file for a description of the class
+template <typename T>
+class Client {
+ public:
-// =================================================================================================
+ // Types of devices to consider
+ const cl_device_type kDeviceType = CL_DEVICE_TYPE_ALL;
-// Shorthand for a BLAS routine with 2 or 3 OpenCL buffers as argument
-template <typename T>
-using Routine2 = std::function<void(const Arguments<T>&,
- const Buffer&, const Buffer&,
- CommandQueue&)>;
-template <typename T>
-using Routine3 = std::function<void(const Arguments<T>&,
- const Buffer&, const Buffer&, const Buffer&,
- CommandQueue&)>;
+ // Shorthand for the routine-specific functions passed to the tester
+ using Routine = std::function<StatusCode(const Arguments<T>&, const Buffers&, CommandQueue&)>;
+ using SetMetric = std::function<void(Arguments<T>&)>;
+ using GetMetric = std::function<size_t(const Arguments<T>&)>;
-// =================================================================================================
+ // The constructor
+ Client(const Routine run_routine, const Routine run_reference,
+ const std::vector<std::string> &options,
+ const GetMetric get_flops, const GetMetric get_bytes);
-// These are the main client functions, setting-up arguments, matrices, OpenCL buffers, etc. After
-// set-up, they call the client routine, passed as argument to this function.
-template <typename T>
-void ClientXY(int argc, char *argv[], Routine2<T> client_routine,
- const std::vector<std::string> &options);
-template <typename T>
-void ClientAXY(int argc, char *argv[], Routine3<T> client_routine,
- const std::vector<std::string> &options);
-template <typename T>
-void ClientAC(int argc, char *argv[], Routine2<T> client_routine,
- const std::vector<std::string> &options);
-template <typename T>
-void ClientABC(int argc, char *argv[], Routine3<T> client_routine,
- const std::vector<std::string> &options, const bool symmetric);
+ // Parses all command-line arguments, filling in the arguments structure. If no command-line
+ // argument is given for a particular argument, it is filled in with a default value.
+ Arguments<T> ParseArguments(int argc, char *argv[], const GetMetric default_a_ld,
+ const GetMetric default_b_ld, const GetMetric default_c_ld);
-// =================================================================================================
+ // The main client function, setting-up arguments, matrices, OpenCL buffers, etc. After set-up, it
+ // calls the client routines.
+ void PerformanceTest(Arguments<T> &args, const SetMetric set_sizes);
-// Parses all command-line arguments, filling in the arguments structure. If no command-line
-// argument is given for a particular argument, it is filled in with a default value.
-template <typename T>
-Arguments<T> ParseArguments(int argc, char *argv[], const std::vector<std::string> &options,
- const std::function<size_t(const Arguments<T>)> default_ld_a);
+ private:
-// Retrieves only the precision command-line argument, since the above function is templated based
-// on the precision
-Precision GetPrecision(int argc, char *argv[]);
+ // Runs a function a given number of times and returns the execution time of the shortest instance
+ double TimedExecution(const size_t num_runs, const Arguments<T> &args, const Buffers &buffers,
+ CommandQueue &queue, Routine run_blas, const std::string &library_name);
-// =================================================================================================
+ // Prints the header of a performance-data table
+ void PrintTableHeader(const bool silent, const std::vector<std::string> &args);
+
+ // Prints a row of performance data, including results of two libraries
+ void PrintTableRow(const Arguments<T>& args, const double ms_clblast, const double ms_clblas);
-// Runs a function a given number of times and returns the execution time of the shortest instance
-double TimedExecution(const size_t num_runs, std::function<void()> main_computation);
+ // The routine-specific functions passed to the tester
+ const Routine run_routine_;
+ const Routine run_reference_;
+ const std::vector<std::string> options_;
+ const GetMetric get_flops_;
+ const GetMetric get_bytes_;
+};
// =================================================================================================
-// Prints the header of a performance-data table
-void PrintTableHeader(const bool silent, const std::vector<std::string> &args);
+// The interface to the performance client. This is a separate function in the header such that it
+// is automatically compiled for each routine, templated by the parameter "C".
+template <typename C, typename T>
+void RunClient(int argc, char *argv[]) {
+
+ // Creates a new client
+ auto client = Client<T>(C::RunRoutine, C::RunReference, C::GetOptions(),
+ C::GetFlops, C::GetBytes);
+
+ // Simple command line argument parser with defaults
+ auto args = client.ParseArguments(argc, argv, C::DefaultLDA, C::DefaultLDB, C::DefaultLDC);
+ if (args.print_help) { return; }
-// Prints a row of performance data, including results of two libraries
-void PrintTableRow(const std::vector<size_t> &args_int, const std::vector<std::string> &args_string,
- const bool abbreviations, const double ms_clblast, const double ms_clblas,
- const unsigned long long flops, const unsigned long long bytes);
+ // Runs the client
+ client.PerformanceTest(args, C::SetSizes);
+}
// =================================================================================================
} // namespace clblast