diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/correctness/tester.cpp | 5 | ||||
-rw-r--r-- | test/performance/client.cpp | 15 | ||||
-rw-r--r-- | test/performance/client.hpp | 3 |
3 files changed, 20 insertions, 3 deletions
diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp index 92e2c1b8..362c5c2c 100644 --- a/test/correctness/tester.cpp +++ b/test/correctness/tester.cpp @@ -15,6 +15,7 @@ #include <vector> #include <iostream> #include <cmath> +#include <cstdlib> #include "test/correctness/tester.hpp" @@ -27,8 +28,8 @@ template <typename T, typename U> Tester<T,U>::Tester(int argc, char *argv[], const bool silent, const std::string &name, const std::vector<std::string> &options): help_("Options given/available:\n"), - platform_(Platform(GetArgument(argc, argv, help_, kArgPlatform, size_t{0}))), - device_(Device(platform_, GetArgument(argc, argv, help_, kArgDevice, size_t{0}))), + platform_(Platform(GetArgument(argc, argv, help_, kArgPlatform, ConvertArgument(std::getenv("CLBLAST_PLATFORM"), size_t{0})))), + device_(Device(platform_, GetArgument(argc, argv, help_, kArgDevice, ConvertArgument(std::getenv("CLBLAST_DEVICE"), size_t{0})))), context_(Context(device_)), queue_(Queue(context_, device_)), full_test_(CheckArgument(argc, argv, help_, kArgFullTest)), diff --git a/test/performance/client.cpp b/test/performance/client.cpp index d0068f8b..aaaab22e 100644 --- a/test/performance/client.cpp +++ b/test/performance/client.cpp @@ -113,6 +113,7 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le args.print_help = CheckArgument(argc, argv, help, kArgHelp); args.silent = CheckArgument(argc, argv, help, kArgQuiet); args.no_abbrv = CheckArgument(argc, argv, help, kArgNoAbbreviations); + warm_up_ = CheckArgument(argc, argv, help, kArgWarmUp); // Prints the chosen (or defaulted) arguments to screen. This also serves as the help message, // which is thus always displayed (unless silence is specified). @@ -244,12 +245,24 @@ template <typename T, typename U> double Client<T,U>::TimedExecution(const size_t num_runs, const Arguments<U> &args, Buffers<T> &buffers, Queue &queue, Routine run_blas, const std::string &library_name) { + auto status = StatusCode::kSuccess; + + // Do an optional warm-up to omit compilation times and initialisations from the measurements + if (warm_up_) { + try { + status = run_blas(args, buffers, queue); + } catch (...) { status = static_cast<StatusCode>(kUnknownError); } + if (status != StatusCode::kSuccess) { + throw std::runtime_error(library_name+" error: "+ToString(static_cast<int>(status))); + } + } + + // Start the timed part auto timings = std::vector<double>(num_runs); for (auto &timing: timings) { auto start_time = std::chrono::steady_clock::now(); // Executes the main computation - auto status = StatusCode::kSuccess; try { status = run_blas(args, buffers, queue); } catch (...) { status = static_cast<StatusCode>(kUnknownError); } diff --git a/test/performance/client.hpp b/test/performance/client.hpp index 5ff2aec7..6d35fced 100644 --- a/test/performance/client.hpp +++ b/test/performance/client.hpp @@ -82,6 +82,9 @@ class Client { const std::vector<std::string> options_; const GetMetric get_flops_; const GetMetric get_bytes_; + + // Extra arguments + bool warm_up_; // if enabled, do a warm-up run first before measuring execution time }; // ================================================================================================= |