summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/correctness/tester.cpp5
-rw-r--r--test/performance/client.cpp15
-rw-r--r--test/performance/client.hpp3
3 files changed, 20 insertions, 3 deletions
diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp
index 92e2c1b8..362c5c2c 100644
--- a/test/correctness/tester.cpp
+++ b/test/correctness/tester.cpp
@@ -15,6 +15,7 @@
#include <vector>
#include <iostream>
#include <cmath>
+#include <cstdlib>
#include "test/correctness/tester.hpp"
@@ -27,8 +28,8 @@ template <typename T, typename U>
Tester<T,U>::Tester(int argc, char *argv[], const bool silent,
const std::string &name, const std::vector<std::string> &options):
help_("Options given/available:\n"),
- platform_(Platform(GetArgument(argc, argv, help_, kArgPlatform, size_t{0}))),
- device_(Device(platform_, GetArgument(argc, argv, help_, kArgDevice, size_t{0}))),
+ platform_(Platform(GetArgument(argc, argv, help_, kArgPlatform, ConvertArgument(std::getenv("CLBLAST_PLATFORM"), size_t{0})))),
+ device_(Device(platform_, GetArgument(argc, argv, help_, kArgDevice, ConvertArgument(std::getenv("CLBLAST_DEVICE"), size_t{0})))),
context_(Context(device_)),
queue_(Queue(context_, device_)),
full_test_(CheckArgument(argc, argv, help_, kArgFullTest)),
diff --git a/test/performance/client.cpp b/test/performance/client.cpp
index d0068f8b..aaaab22e 100644
--- a/test/performance/client.cpp
+++ b/test/performance/client.cpp
@@ -113,6 +113,7 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le
args.print_help = CheckArgument(argc, argv, help, kArgHelp);
args.silent = CheckArgument(argc, argv, help, kArgQuiet);
args.no_abbrv = CheckArgument(argc, argv, help, kArgNoAbbreviations);
+ warm_up_ = CheckArgument(argc, argv, help, kArgWarmUp);
// Prints the chosen (or defaulted) arguments to screen. This also serves as the help message,
// which is thus always displayed (unless silence is specified).
@@ -244,12 +245,24 @@ template <typename T, typename U>
double Client<T,U>::TimedExecution(const size_t num_runs, const Arguments<U> &args,
Buffers<T> &buffers, Queue &queue,
Routine run_blas, const std::string &library_name) {
+ auto status = StatusCode::kSuccess;
+
+ // Do an optional warm-up to omit compilation times and initialisations from the measurements
+ if (warm_up_) {
+ try {
+ status = run_blas(args, buffers, queue);
+ } catch (...) { status = static_cast<StatusCode>(kUnknownError); }
+ if (status != StatusCode::kSuccess) {
+ throw std::runtime_error(library_name+" error: "+ToString(static_cast<int>(status)));
+ }
+ }
+
+ // Start the timed part
auto timings = std::vector<double>(num_runs);
for (auto &timing: timings) {
auto start_time = std::chrono::steady_clock::now();
// Executes the main computation
- auto status = StatusCode::kSuccess;
try {
status = run_blas(args, buffers, queue);
} catch (...) { status = static_cast<StatusCode>(kUnknownError); }
diff --git a/test/performance/client.hpp b/test/performance/client.hpp
index 5ff2aec7..6d35fced 100644
--- a/test/performance/client.hpp
+++ b/test/performance/client.hpp
@@ -82,6 +82,9 @@ class Client {
const std::vector<std::string> options_;
const GetMetric get_flops_;
const GetMetric get_bytes_;
+
+ // Extra arguments
+ bool warm_up_; // if enabled, do a warm-up run first before measuring execution time
};
// =================================================================================================