From 82469fc76432f37d955d87b05bf80b02026d19ff Mon Sep 17 00:00:00 2001 From: CNugteren Date: Wed, 8 Jul 2015 07:21:44 +0200 Subject: The testers now distinguish between the memory and alpha/beta data-type --- test/correctness/tester.cc | 217 ++++++++++++++++++++++++--------------------- 1 file changed, 115 insertions(+), 102 deletions(-) (limited to 'test/correctness/tester.cc') diff --git a/test/correctness/tester.cc b/test/correctness/tester.cc index db4ee619..378968ed 100644 --- a/test/correctness/tester.cc +++ b/test/correctness/tester.cc @@ -23,9 +23,9 @@ namespace clblast { // General constructor for all CLBlast testers. It prints out the test header to stdout and sets-up // the clBLAS library for reference. -template -Tester::Tester(int argc, char *argv[], const bool silent, - const std::string &name, const std::vector &options): +template +Tester::Tester(int argc, char *argv[], const bool silent, + const std::string &name, const std::vector &options): help_("Options given/available:\n"), platform_(Platform(GetArgument(argc, argv, help_, kArgPlatform, size_t{0}))), device_(Device(platform_, kDeviceType, GetArgument(argc, argv, help_, kArgDevice, size_t{0}))), @@ -51,7 +51,7 @@ Tester::Tester(int argc, char *argv[], const bool silent, kPrintMessage.c_str(), name.c_str(), kPrintEnd.c_str()); // Checks whether the precision is supported - if (!PrecisionSupported()) { + if (!PrecisionSupported(device_)) { fprintf(stdout, "\n* All tests skipped: %sUnsupported precision%s\n", kPrintWarning.c_str(), kPrintEnd.c_str()); return; @@ -76,9 +76,9 @@ Tester::Tester(int argc, char *argv[], const bool silent, } // Destructor prints the summary of the test cases and cleans-up the clBLAS library -template -Tester::~Tester() { - if (PrecisionSupported()) { +template +Tester::~Tester() { + if (PrecisionSupported(device_)) { fprintf(stdout, "* Completed all test-cases for this routine. Results:\n"); fprintf(stdout, " %lu test(s) passed\n", tests_passed_); if (tests_skipped_ > 0) { fprintf(stdout, "%s", kPrintWarning.c_str()); } @@ -94,8 +94,8 @@ Tester::~Tester() { // Function called at the start of each test. This prints a header with information about the // test and re-initializes all test data-structures. -template -void Tester::TestStart(const std::string &test_name, const std::string &test_configuration) { +template +void Tester::TestStart(const std::string &test_name, const std::string &test_configuration) { // Prints the header fprintf(stdout, "* Testing %s'%s'%s for %s'%s'%s:\n", @@ -113,8 +113,8 @@ void Tester::TestStart(const std::string &test_name, const std::string &test_ // Function called at the end of each test. This prints errors if any occured. It also prints a // summary of the number of sub-tests passed/failed. -template -void Tester::TestEnd() { +template +void Tester::TestEnd() { fprintf(stdout, "\n"); tests_passed_ += num_passed_; tests_failed_ += num_skipped_; @@ -172,45 +172,9 @@ void Tester::TestEnd() { // ================================================================================================= -// Compares two floating point values and returns whether they are within an acceptable error -// margin. This replaces GTest's EXPECT_NEAR(). -template -bool Tester::TestSimilarity(const T val1, const T val2) { - const auto difference = std::fabs(val1 - val2); - - // Shortcut, handles infinities - if (val1 == val2) { - return true; - } - // The values are zero or very small: the relative error is less meaningful - else if (val1 == 0 || val2 == 0 || difference < static_cast(kErrorMarginAbsolute)) { - return (difference < static_cast(kErrorMarginAbsolute)); - } - // Use relative error - else { - return (difference / (std::fabs(val1)+std::fabs(val2))) < static_cast(kErrorMarginRelative); - } -} - -// Specialisations for complex data-types -template <> -bool Tester::TestSimilarity(const float2 val1, const float2 val2) { - auto real = Tester::TestSimilarity(val1.real(), val2.real()); - auto imag = Tester::TestSimilarity(val1.imag(), val2.imag()); - return (real && imag); -} -template <> -bool Tester::TestSimilarity(const double2 val1, const double2 val2) { - auto real = Tester::TestSimilarity(val1.real(), val2.real()); - auto imag = Tester::TestSimilarity(val1.imag(), val2.imag()); - return (real && imag); -} - -// ================================================================================================= - // Handles a 'pass' or 'error' depending on whether there are any errors -template -void Tester::TestErrorCount(const size_t errors, const size_t size, const Arguments &args) { +template +void Tester::TestErrorCount(const size_t errors, const size_t size, const Arguments &args) { // Finished successfully if (errors == 0) { @@ -228,9 +192,9 @@ void Tester::TestErrorCount(const size_t errors, const size_t size, const Arg // Compares two status codes for equality. The outcome can be a pass (they are the same), a warning // (CLBlast reported a compilation error), or an error (they are different). -template -void Tester::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status, - const Arguments &args) { +template +void Tester::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status, + const Arguments &args) { // Finished successfully if (clblas_status == clblast_status) { @@ -261,62 +225,26 @@ void Tester::TestErrorCodes(const StatusCode clblas_status, const StatusCode // ================================================================================================= -// Retrieves a list of example scalar values, used for the alpha and beta arguments for the various -// routines. This function is specialised for the different data-types. -template <> -const std::vector Tester::GetExampleScalars() { - if (full_test_) { return {0.0f, 1.0f, 3.14f}; } - else { return {3.14f}; } -} -template <> -const std::vector Tester::GetExampleScalars() { - if (full_test_) { return {0.0, 1.0, 3.14}; } - else { return {3.14}; } -} -template <> -const std::vector Tester::GetExampleScalars() { - if (full_test_) { return {{0.0f, 0.0f}, {1.0f, 1.3f}, {2.42f, 3.14f}}; } - else { return {{2.42f, 3.14f}}; } -} -template <> -const std::vector Tester::GetExampleScalars() { - if (full_test_) { return {{0.0, 0.0}, {1.0, 1.3}, {2.42, 3.14}}; } - else { return {{2.42, 3.14}}; } -} - // Retrieves the offset values to test with -template -const std::vector Tester::GetOffsets() { +template +const std::vector Tester::GetOffsets() const { if (full_test_) { return {0, 10}; } else { return {0}; } } // ================================================================================================= -template <> bool Tester::PrecisionSupported() const { return true; } -template <> bool Tester::PrecisionSupported() const { return true; } -template <> bool Tester::PrecisionSupported() const { - auto extensions = device_.Extensions(); - return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; -} -template <> bool Tester::PrecisionSupported() const { - auto extensions = device_.Extensions(); - return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; -} - -// ================================================================================================= - // A test can either pass, be skipped, or fail -template -void Tester::ReportPass() { +template +void Tester::ReportPass() { num_passed_++; } -template -void Tester::ReportSkipped() { +template +void Tester::ReportSkipped() { num_skipped_++; } -template -void Tester::ReportError(const ErrorLogEntry &error_log_entry) { +template +void Tester::ReportError(const ErrorLogEntry &error_log_entry) { error_log_.push_back(error_log_entry); num_failed_++; } @@ -325,8 +253,8 @@ void Tester::ReportError(const ErrorLogEntry &error_log_entry) { // Prints the test-result symbol to screen. This function limits the maximum number of symbols per // line by printing newlines once every so many calls. -template -void Tester::PrintTestResult(const std::string &message) { +template +void Tester::PrintTestResult(const std::string &message) { if (print_count_ == kResultsPerLine) { print_count_ = 0; fprintf(stdout, "\n "); @@ -336,13 +264,98 @@ void Tester::PrintTestResult(const std::string &message) { print_count_++; } +// ================================================================================================= +// Below are the non-member functions (separated because of otherwise required partial class +// template specialization) +// ================================================================================================= + +// Compares two floating point values and returns whether they are within an acceptable error +// margin. This replaces GTest's EXPECT_NEAR(). +template +bool TestSimilarity(const T val1, const T val2) { + const auto difference = std::fabs(val1 - val2); + + // Set the allowed error margin for floating-point comparisons + constexpr auto kErrorMarginRelative = 1.0e-2; + constexpr auto kErrorMarginAbsolute = 1.0e-10; + + // Shortcut, handles infinities + if (val1 == val2) { + return true; + } + // The values are zero or very small: the relative error is less meaningful + else if (val1 == 0 || val2 == 0 || difference < static_cast(kErrorMarginAbsolute)) { + return (difference < static_cast(kErrorMarginAbsolute)); + } + // Use relative error + else { + const auto absolute_sum = std::fabs(val1) + std::fabs(val2); + return (difference / absolute_sum) < static_cast(kErrorMarginRelative); + } +} + +// Compiles the default case for non-complex data-types +template bool TestSimilarity(const float, const float); +template bool TestSimilarity(const double, const double); + +// Specialisations for complex data-types +template <> +bool TestSimilarity(const float2 val1, const float2 val2) { + auto real = TestSimilarity(val1.real(), val2.real()); + auto imag = TestSimilarity(val1.imag(), val2.imag()); + return (real && imag); +} +template <> +bool TestSimilarity(const double2 val1, const double2 val2) { + auto real = TestSimilarity(val1.real(), val2.real()); + auto imag = TestSimilarity(val1.imag(), val2.imag()); + return (real && imag); +} + +// ================================================================================================= + +// Retrieves a list of example scalar values, used for the alpha and beta arguments for the various +// routines. This function is specialised for the different data-types. +template <> const std::vector GetExampleScalars(const bool full_test) { + if (full_test) { return {0.0f, 1.0f, 3.14f}; } + else { return {3.14f}; } +} +template <> const std::vector GetExampleScalars(const bool full_test) { + if (full_test) { return {0.0, 1.0, 3.14}; } + else { return {3.14}; } +} +template <> const std::vector GetExampleScalars(const bool full_test) { + if (full_test) { return {{0.0f, 0.0f}, {1.0f, 1.3f}, {2.42f, 3.14f}}; } + else { return {{2.42f, 3.14f}}; } +} +template <> const std::vector GetExampleScalars(const bool full_test) { + if (full_test) { return {{0.0, 0.0}, {1.0, 1.3}, {2.42, 3.14}}; } + else { return {{2.42, 3.14}}; } +} + +// ================================================================================================= + +// Returns false is this precision is not supported by the device +template <> bool PrecisionSupported(const Device &) { return true; } +template <> bool PrecisionSupported(const Device &) { return true; } +template <> bool PrecisionSupported(const Device &device) { + auto extensions = device.Extensions(); + return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; +} +template <> bool PrecisionSupported(const Device &device) { + auto extensions = device.Extensions(); + return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; +} + // ================================================================================================= // Compiles the templated class -template class Tester; -template class Tester; -template class Tester; -template class Tester; +template class Tester; +template class Tester; +template class Tester; +template class Tester; +template class Tester; +template class Tester; // ================================================================================================= } // namespace clblast -- cgit v1.2.3