summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCNugteren <web@cedricnugteren.nl>2015-07-10 07:18:12 +0200
committerCNugteren <web@cedricnugteren.nl>2015-07-10 07:18:12 +0200
commit2fe3fe15801f8ef11b38bfd93d7d68fbb37253a1 (patch)
treeb2dc073d3c053debc0cb6132165b5d03d0b9e26a
parent5578d5ab282d63ad47a767dcbebb94b83195230d (diff)
The clients now distinguish between the memory and alpha/beta data-type
-rw-r--r--test/performance/client.cc52
-rw-r--r--test/performance/client.h24
-rw-r--r--test/performance/routines/xaxpy.cc12
-rw-r--r--test/performance/routines/xgemm.cc12
-rw-r--r--test/performance/routines/xgemv.cc12
-rw-r--r--test/performance/routines/xsymm.cc12
-rw-r--r--test/performance/routines/xsyr2k.cc12
-rw-r--r--test/performance/routines/xsyrk.cc12
-rw-r--r--test/performance/routines/xtrmm.cc12
9 files changed, 96 insertions, 64 deletions
diff --git a/test/performance/client.cc b/test/performance/client.cc
index fad0f3a9..676e88e4 100644
--- a/test/performance/client.cc
+++ b/test/performance/client.cc
@@ -22,10 +22,10 @@ namespace clblast {
// =================================================================================================
// Constructor
-template <typename T>
-Client<T>::Client(const Routine run_routine, const Routine run_reference,
- const std::vector<std::string> &options,
- const GetMetric get_flops, const GetMetric get_bytes):
+template <typename T, typename U>
+Client<T,U>::Client(const Routine run_routine, const Routine run_reference,
+ const std::vector<std::string> &options,
+ const GetMetric get_flops, const GetMetric get_bytes):
run_routine_(run_routine),
run_reference_(run_reference),
options_(options),
@@ -38,10 +38,10 @@ Client<T>::Client(const Routine run_routine, const Routine run_reference,
// Parses all arguments available for the CLBlast client testers. Some arguments might not be
// applicable, but are searched for anyway to be able to create one common argument parser. All
// arguments have a default value in case they are not found.
-template <typename T>
-Arguments<T> Client<T>::ParseArguments(int argc, char *argv[], const GetMetric default_a_ld,
- const GetMetric default_b_ld, const GetMetric default_c_ld) {
- auto args = Arguments<T>{};
+template <typename T, typename U>
+Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const GetMetric default_a_ld,
+ const GetMetric default_b_ld, const GetMetric default_c_ld) {
+ auto args = Arguments<U>{};
auto help = std::string{"Options given/available:\n"};
// These are the options which are not for every client: they are optional
@@ -75,8 +75,8 @@ Arguments<T> Client<T>::ParseArguments(int argc, char *argv[], const GetMetric d
if (o == kArgCOffset) { args.c_offset = GetArgument(argc, argv, help, kArgCOffset, size_t{0}); }
// Scalar values
- if (o == kArgAlpha) { args.alpha = GetArgument(argc, argv, help, kArgAlpha, GetScalar<T>()); }
- if (o == kArgBeta) { args.beta = GetArgument(argc, argv, help, kArgBeta, GetScalar<T>()); }
+ if (o == kArgAlpha) { args.alpha = GetArgument(argc, argv, help, kArgAlpha, GetScalar<U>()); }
+ if (o == kArgBeta) { args.beta = GetArgument(argc, argv, help, kArgBeta, GetScalar<U>()); }
}
// These are the options common to all routines
@@ -102,8 +102,8 @@ Arguments<T> Client<T>::ParseArguments(int argc, char *argv[], const GetMetric d
// =================================================================================================
// This is main performance tester
-template <typename T>
-void Client<T>::PerformanceTest(Arguments<T> &args, const SetMetric set_sizes) {
+template <typename T, typename U>
+void Client<T,U>::PerformanceTest(Arguments<U> &args, const SetMetric set_sizes) {
// Prints the header of the output table
PrintTableHeader(args.silent, options_);
@@ -174,10 +174,10 @@ void Client<T>::PerformanceTest(Arguments<T> &args, const SetMetric set_sizes) {
// Creates a vector of timing results, filled with execution times of the 'main computation'. The
// timing is performed using the milliseconds chrono functions. The function returns the minimum
// value found in the vector of timing results. The return value is in milliseconds.
-template <typename T>
-double Client<T>::TimedExecution(const size_t num_runs, const Arguments<T> &args,
- const Buffers &buffers, CommandQueue &queue,
- Routine run_blas, const std::string &library_name) {
+template <typename T, typename U>
+double Client<T,U>::TimedExecution(const size_t num_runs, const Arguments<U> &args,
+ const Buffers &buffers, CommandQueue &queue,
+ Routine run_blas, const std::string &library_name) {
auto timings = std::vector<double>(num_runs);
for (auto &timing: timings) {
auto start_time = std::chrono::steady_clock::now();
@@ -198,8 +198,8 @@ double Client<T>::TimedExecution(const size_t num_runs, const Arguments<T> &args
// =================================================================================================
// Prints the header of the performance table
-template <typename T>
-void Client<T>::PrintTableHeader(const bool silent, const std::vector<std::string> &args) {
+template <typename T, typename U>
+void Client<T,U>::PrintTableHeader(const bool silent, const std::vector<std::string> &args) {
if (!silent) {
for (auto i=size_t{0}; i<args.size(); ++i) { fprintf(stdout, "%9s ", ""); }
fprintf(stdout, " | <-- CLBlast --> | <-- clBLAS --> |\n");
@@ -210,9 +210,9 @@ void Client<T>::PrintTableHeader(const bool silent, const std::vector<std::strin
}
// Print a performance-result row
-template <typename T>
-void Client<T>::PrintTableRow(const Arguments<T>& args, const double ms_clblast,
- const double ms_clblas) {
+template <typename T, typename U>
+void Client<T,U>::PrintTableRow(const Arguments<U>& args, const double ms_clblast,
+ const double ms_clblas) {
// Creates a vector of relevant variables
auto integers = std::vector<size_t>{};
@@ -276,10 +276,12 @@ void Client<T>::PrintTableRow(const Arguments<T>& args, const double ms_clblast,
// =================================================================================================
// Compiles the templated class
-template class Client<float>;
-template class Client<double>;
-template class Client<float2>;
-template class Client<double2>;
+template class Client<float,float>;
+template class Client<double,double>;
+template class Client<float2,float2>;
+template class Client<double2,double2>;
+template class Client<float2,float>;
+template class Client<double2,double>;
// =================================================================================================
} // namespace clblast
diff --git a/test/performance/client.h b/test/performance/client.h
index f9f219d0..c9095967 100644
--- a/test/performance/client.h
+++ b/test/performance/client.h
@@ -10,6 +10,8 @@
// This class implements the performance-test client. It is generic for all CLBlast routines by
// taking a number of routine-specific functions as arguments, such as how to compute buffer sizes
// or how to get the FLOPS count.
+// Typename T: the data-type of the routine's memory buffers (==precision)
+// Typename U: the data-type of the alpha and beta arguments
//
// This file also provides the common interface to the performance client (see the 'RunClient'
// function for details).
@@ -32,7 +34,7 @@ namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
-template <typename T>
+template <typename T, typename U>
class Client {
public:
@@ -40,9 +42,9 @@ class Client {
const cl_device_type kDeviceType = CL_DEVICE_TYPE_ALL;
// Shorthand for the routine-specific functions passed to the tester
- using Routine = std::function<StatusCode(const Arguments<T>&, const Buffers&, CommandQueue&)>;
- using SetMetric = std::function<void(Arguments<T>&)>;
- using GetMetric = std::function<size_t(const Arguments<T>&)>;
+ using Routine = std::function<StatusCode(const Arguments<U>&, const Buffers&, CommandQueue&)>;
+ using SetMetric = std::function<void(Arguments<U>&)>;
+ using GetMetric = std::function<size_t(const Arguments<U>&)>;
// The constructor
Client(const Routine run_routine, const Routine run_reference,
@@ -51,24 +53,24 @@ class Client {
// Parses all command-line arguments, filling in the arguments structure. If no command-line
// argument is given for a particular argument, it is filled in with a default value.
- Arguments<T> ParseArguments(int argc, char *argv[], const GetMetric default_a_ld,
+ Arguments<U> ParseArguments(int argc, char *argv[], const GetMetric default_a_ld,
const GetMetric default_b_ld, const GetMetric default_c_ld);
// The main client function, setting-up arguments, matrices, OpenCL buffers, etc. After set-up, it
// calls the client routines.
- void PerformanceTest(Arguments<T> &args, const SetMetric set_sizes);
+ void PerformanceTest(Arguments<U> &args, const SetMetric set_sizes);
private:
// Runs a function a given number of times and returns the execution time of the shortest instance
- double TimedExecution(const size_t num_runs, const Arguments<T> &args, const Buffers &buffers,
+ double TimedExecution(const size_t num_runs, const Arguments<U> &args, const Buffers &buffers,
CommandQueue &queue, Routine run_blas, const std::string &library_name);
// Prints the header of a performance-data table
void PrintTableHeader(const bool silent, const std::vector<std::string> &args);
// Prints a row of performance data, including results of two libraries
- void PrintTableRow(const Arguments<T>& args, const double ms_clblast, const double ms_clblas);
+ void PrintTableRow(const Arguments<U>& args, const double ms_clblast, const double ms_clblas);
// The routine-specific functions passed to the tester
const Routine run_routine_;
@@ -82,12 +84,12 @@ class Client {
// The interface to the performance client. This is a separate function in the header such that it
// is automatically compiled for each routine, templated by the parameter "C".
-template <typename C, typename T>
+template <typename C, typename T, typename U>
void RunClient(int argc, char *argv[]) {
// Creates a new client
- auto client = Client<T>(C::RunRoutine, C::RunReference, C::GetOptions(),
- C::GetFlops, C::GetBytes);
+ auto client = Client<T,U>(C::RunRoutine, C::RunReference, C::GetOptions(),
+ C::GetFlops, C::GetBytes);
// Simple command line argument parser with defaults
auto args = client.ParseArguments(argc, argv, C::DefaultLDA, C::DefaultLDB, C::DefaultLDC);
diff --git a/test/performance/routines/xaxpy.cc b/test/performance/routines/xaxpy.cc
index 3ced80ed..6a2b96c1 100644
--- a/test/performance/routines/xaxpy.cc
+++ b/test/performance/routines/xaxpy.cc
@@ -16,19 +16,23 @@
// =================================================================================================
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
case clblast::Precision::kHalf:
throw std::runtime_error("Unsupported precision mode");
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXaxpy<float>, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXaxpy<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXaxpy<double>, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXaxpy<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXaxpy<clblast::float2>, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXaxpy<float2>, float2, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXaxpy<clblast::double2>, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXaxpy<double2>, double2, double2>(argc, argv); break;
}
return 0;
}
diff --git a/test/performance/routines/xgemm.cc b/test/performance/routines/xgemm.cc
index 36c74b9a..9a02e595 100644
--- a/test/performance/routines/xgemm.cc
+++ b/test/performance/routines/xgemm.cc
@@ -16,19 +16,23 @@
// =================================================================================================
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
case clblast::Precision::kHalf:
throw std::runtime_error("Unsupported precision mode");
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXgemm<float>, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXgemm<double>, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXgemm<clblast::float2>, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<float2>, float2, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXgemm<clblast::double2>, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<double2>, double2, double2>(argc, argv); break;
}
return 0;
}
diff --git a/test/performance/routines/xgemv.cc b/test/performance/routines/xgemv.cc
index 183dd4a1..6f69ef21 100644
--- a/test/performance/routines/xgemv.cc
+++ b/test/performance/routines/xgemv.cc
@@ -16,19 +16,23 @@
// =================================================================================================
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
case clblast::Precision::kHalf:
throw std::runtime_error("Unsupported precision mode");
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXgemv<float>, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemv<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXgemv<double>, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemv<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXgemv<clblast::float2>, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemv<float2>, float2, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXgemv<clblast::double2>, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemv<double2>, double2, double2>(argc, argv); break;
}
return 0;
}
diff --git a/test/performance/routines/xsymm.cc b/test/performance/routines/xsymm.cc
index 0c7f5e1e..8738ceda 100644
--- a/test/performance/routines/xsymm.cc
+++ b/test/performance/routines/xsymm.cc
@@ -16,19 +16,23 @@
// =================================================================================================
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
case clblast::Precision::kHalf:
throw std::runtime_error("Unsupported precision mode");
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXsymm<float>, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsymm<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXsymm<double>, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsymm<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXsymm<clblast::float2>, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsymm<float2>, float2, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXsymm<clblast::double2>, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsymm<double2>, double2, double2>(argc, argv); break;
}
return 0;
}
diff --git a/test/performance/routines/xsyr2k.cc b/test/performance/routines/xsyr2k.cc
index 63b50df6..e4c76229 100644
--- a/test/performance/routines/xsyr2k.cc
+++ b/test/performance/routines/xsyr2k.cc
@@ -16,19 +16,23 @@
// =================================================================================================
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
case clblast::Precision::kHalf:
throw std::runtime_error("Unsupported precision mode");
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXsyr2k<float>, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsyr2k<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXsyr2k<double>, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsyr2k<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXsyr2k<clblast::float2>, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsyr2k<float2>, float2, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXsyr2k<clblast::double2>, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsyr2k<double2>, double2, double2>(argc, argv); break;
}
return 0;
}
diff --git a/test/performance/routines/xsyrk.cc b/test/performance/routines/xsyrk.cc
index 9022d4f8..53fecb69 100644
--- a/test/performance/routines/xsyrk.cc
+++ b/test/performance/routines/xsyrk.cc
@@ -16,19 +16,23 @@
// =================================================================================================
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
case clblast::Precision::kHalf:
throw std::runtime_error("Unsupported precision mode");
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXsyrk<float>, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsyrk<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXsyrk<double>, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsyrk<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXsyrk<clblast::float2>, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsyrk<float2>, float2, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXsyrk<clblast::double2>, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXsyrk<double2>, double2, double2>(argc, argv); break;
}
return 0;
}
diff --git a/test/performance/routines/xtrmm.cc b/test/performance/routines/xtrmm.cc
index 91dcbd07..2ab9ce77 100644
--- a/test/performance/routines/xtrmm.cc
+++ b/test/performance/routines/xtrmm.cc
@@ -16,19 +16,23 @@
// =================================================================================================
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
case clblast::Precision::kHalf:
throw std::runtime_error("Unsupported precision mode");
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXtrmm<float>, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXtrmm<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXtrmm<double>, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXtrmm<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXtrmm<clblast::float2>, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXtrmm<float2>, float2, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXtrmm<clblast::double2>, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXtrmm<double2>, double2, double2>(argc, argv); break;
}
return 0;
}