diff options
author | CNugteren <web@cedricnugteren.nl> | 2015-07-08 07:21:44 +0200 |
---|---|---|
committer | CNugteren <web@cedricnugteren.nl> | 2015-07-08 07:21:44 +0200 |
commit | 82469fc76432f37d955d87b05bf80b02026d19ff (patch) | |
tree | 30d12a1c47ddcdf9ed59ee5bbddda2f019c2140c /test/correctness/tester.cc | |
parent | 599f9a70a6bb2388c27b7981276e4d39497f90fb (diff) |
The testers now distinguish between the memory and alpha/beta data-type
Diffstat (limited to 'test/correctness/tester.cc')
-rw-r--r-- | test/correctness/tester.cc | 217 |
1 files changed, 115 insertions, 102 deletions
diff --git a/test/correctness/tester.cc b/test/correctness/tester.cc index db4ee619..378968ed 100644 --- a/test/correctness/tester.cc +++ b/test/correctness/tester.cc @@ -23,9 +23,9 @@ namespace clblast { // General constructor for all CLBlast testers. It prints out the test header to stdout and sets-up // the clBLAS library for reference. -template <typename T> -Tester<T>::Tester(int argc, char *argv[], const bool silent, - const std::string &name, const std::vector<std::string> &options): +template <typename T, typename U> +Tester<T,U>::Tester(int argc, char *argv[], const bool silent, + const std::string &name, const std::vector<std::string> &options): help_("Options given/available:\n"), platform_(Platform(GetArgument(argc, argv, help_, kArgPlatform, size_t{0}))), device_(Device(platform_, kDeviceType, GetArgument(argc, argv, help_, kArgDevice, size_t{0}))), @@ -51,7 +51,7 @@ Tester<T>::Tester(int argc, char *argv[], const bool silent, kPrintMessage.c_str(), name.c_str(), kPrintEnd.c_str()); // Checks whether the precision is supported - if (!PrecisionSupported()) { + if (!PrecisionSupported<T>(device_)) { fprintf(stdout, "\n* All tests skipped: %sUnsupported precision%s\n", kPrintWarning.c_str(), kPrintEnd.c_str()); return; @@ -76,9 +76,9 @@ Tester<T>::Tester(int argc, char *argv[], const bool silent, } // Destructor prints the summary of the test cases and cleans-up the clBLAS library -template <typename T> -Tester<T>::~Tester() { - if (PrecisionSupported()) { +template <typename T, typename U> +Tester<T,U>::~Tester() { + if (PrecisionSupported<T>(device_)) { fprintf(stdout, "* Completed all test-cases for this routine. Results:\n"); fprintf(stdout, " %lu test(s) passed\n", tests_passed_); if (tests_skipped_ > 0) { fprintf(stdout, "%s", kPrintWarning.c_str()); } @@ -94,8 +94,8 @@ Tester<T>::~Tester() { // Function called at the start of each test. This prints a header with information about the // test and re-initializes all test data-structures. -template <typename T> -void Tester<T>::TestStart(const std::string &test_name, const std::string &test_configuration) { +template <typename T, typename U> +void Tester<T,U>::TestStart(const std::string &test_name, const std::string &test_configuration) { // Prints the header fprintf(stdout, "* Testing %s'%s'%s for %s'%s'%s:\n", @@ -113,8 +113,8 @@ void Tester<T>::TestStart(const std::string &test_name, const std::string &test_ // Function called at the end of each test. This prints errors if any occured. It also prints a // summary of the number of sub-tests passed/failed. -template <typename T> -void Tester<T>::TestEnd() { +template <typename T, typename U> +void Tester<T,U>::TestEnd() { fprintf(stdout, "\n"); tests_passed_ += num_passed_; tests_failed_ += num_skipped_; @@ -172,45 +172,9 @@ void Tester<T>::TestEnd() { // ================================================================================================= -// Compares two floating point values and returns whether they are within an acceptable error -// margin. This replaces GTest's EXPECT_NEAR(). -template <typename T> -bool Tester<T>::TestSimilarity(const T val1, const T val2) { - const auto difference = std::fabs(val1 - val2); - - // Shortcut, handles infinities - if (val1 == val2) { - return true; - } - // The values are zero or very small: the relative error is less meaningful - else if (val1 == 0 || val2 == 0 || difference < static_cast<T>(kErrorMarginAbsolute)) { - return (difference < static_cast<T>(kErrorMarginAbsolute)); - } - // Use relative error - else { - return (difference / (std::fabs(val1)+std::fabs(val2))) < static_cast<T>(kErrorMarginRelative); - } -} - -// Specialisations for complex data-types -template <> -bool Tester<float2>::TestSimilarity(const float2 val1, const float2 val2) { - auto real = Tester<float>::TestSimilarity(val1.real(), val2.real()); - auto imag = Tester<float>::TestSimilarity(val1.imag(), val2.imag()); - return (real && imag); -} -template <> -bool Tester<double2>::TestSimilarity(const double2 val1, const double2 val2) { - auto real = Tester<double>::TestSimilarity(val1.real(), val2.real()); - auto imag = Tester<double>::TestSimilarity(val1.imag(), val2.imag()); - return (real && imag); -} - -// ================================================================================================= - // Handles a 'pass' or 'error' depending on whether there are any errors -template <typename T> -void Tester<T>::TestErrorCount(const size_t errors, const size_t size, const Arguments<T> &args) { +template <typename T, typename U> +void Tester<T,U>::TestErrorCount(const size_t errors, const size_t size, const Arguments<U> &args) { // Finished successfully if (errors == 0) { @@ -228,9 +192,9 @@ void Tester<T>::TestErrorCount(const size_t errors, const size_t size, const Arg // Compares two status codes for equality. The outcome can be a pass (they are the same), a warning // (CLBlast reported a compilation error), or an error (they are different). -template <typename T> -void Tester<T>::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status, - const Arguments<T> &args) { +template <typename T, typename U> +void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status, + const Arguments<U> &args) { // Finished successfully if (clblas_status == clblast_status) { @@ -261,62 +225,26 @@ void Tester<T>::TestErrorCodes(const StatusCode clblas_status, const StatusCode // ================================================================================================= -// Retrieves a list of example scalar values, used for the alpha and beta arguments for the various -// routines. This function is specialised for the different data-types. -template <> -const std::vector<float> Tester<float>::GetExampleScalars() { - if (full_test_) { return {0.0f, 1.0f, 3.14f}; } - else { return {3.14f}; } -} -template <> -const std::vector<double> Tester<double>::GetExampleScalars() { - if (full_test_) { return {0.0, 1.0, 3.14}; } - else { return {3.14}; } -} -template <> -const std::vector<float2> Tester<float2>::GetExampleScalars() { - if (full_test_) { return {{0.0f, 0.0f}, {1.0f, 1.3f}, {2.42f, 3.14f}}; } - else { return {{2.42f, 3.14f}}; } -} -template <> -const std::vector<double2> Tester<double2>::GetExampleScalars() { - if (full_test_) { return {{0.0, 0.0}, {1.0, 1.3}, {2.42, 3.14}}; } - else { return {{2.42, 3.14}}; } -} - // Retrieves the offset values to test with -template <typename T> -const std::vector<size_t> Tester<T>::GetOffsets() { +template <typename T, typename U> +const std::vector<size_t> Tester<T,U>::GetOffsets() const { if (full_test_) { return {0, 10}; } else { return {0}; } } // ================================================================================================= -template <> bool Tester<float>::PrecisionSupported() const { return true; } -template <> bool Tester<float2>::PrecisionSupported() const { return true; } -template <> bool Tester<double>::PrecisionSupported() const { - auto extensions = device_.Extensions(); - return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; -} -template <> bool Tester<double2>::PrecisionSupported() const { - auto extensions = device_.Extensions(); - return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; -} - -// ================================================================================================= - // A test can either pass, be skipped, or fail -template <typename T> -void Tester<T>::ReportPass() { +template <typename T, typename U> +void Tester<T,U>::ReportPass() { num_passed_++; } -template <typename T> -void Tester<T>::ReportSkipped() { +template <typename T, typename U> +void Tester<T,U>::ReportSkipped() { num_skipped_++; } -template <typename T> -void Tester<T>::ReportError(const ErrorLogEntry &error_log_entry) { +template <typename T, typename U> +void Tester<T,U>::ReportError(const ErrorLogEntry &error_log_entry) { error_log_.push_back(error_log_entry); num_failed_++; } @@ -325,8 +253,8 @@ void Tester<T>::ReportError(const ErrorLogEntry &error_log_entry) { // Prints the test-result symbol to screen. This function limits the maximum number of symbols per // line by printing newlines once every so many calls. -template <typename T> -void Tester<T>::PrintTestResult(const std::string &message) { +template <typename T, typename U> +void Tester<T,U>::PrintTestResult(const std::string &message) { if (print_count_ == kResultsPerLine) { print_count_ = 0; fprintf(stdout, "\n "); @@ -337,12 +265,97 @@ void Tester<T>::PrintTestResult(const std::string &message) { } // ================================================================================================= +// Below are the non-member functions (separated because of otherwise required partial class +// template specialization) +// ================================================================================================= + +// Compares two floating point values and returns whether they are within an acceptable error +// margin. This replaces GTest's EXPECT_NEAR(). +template <typename T> +bool TestSimilarity(const T val1, const T val2) { + const auto difference = std::fabs(val1 - val2); + + // Set the allowed error margin for floating-point comparisons + constexpr auto kErrorMarginRelative = 1.0e-2; + constexpr auto kErrorMarginAbsolute = 1.0e-10; + + // Shortcut, handles infinities + if (val1 == val2) { + return true; + } + // The values are zero or very small: the relative error is less meaningful + else if (val1 == 0 || val2 == 0 || difference < static_cast<T>(kErrorMarginAbsolute)) { + return (difference < static_cast<T>(kErrorMarginAbsolute)); + } + // Use relative error + else { + const auto absolute_sum = std::fabs(val1) + std::fabs(val2); + return (difference / absolute_sum) < static_cast<T>(kErrorMarginRelative); + } +} + +// Compiles the default case for non-complex data-types +template bool TestSimilarity<float>(const float, const float); +template bool TestSimilarity<double>(const double, const double); + +// Specialisations for complex data-types +template <> +bool TestSimilarity(const float2 val1, const float2 val2) { + auto real = TestSimilarity(val1.real(), val2.real()); + auto imag = TestSimilarity(val1.imag(), val2.imag()); + return (real && imag); +} +template <> +bool TestSimilarity(const double2 val1, const double2 val2) { + auto real = TestSimilarity(val1.real(), val2.real()); + auto imag = TestSimilarity(val1.imag(), val2.imag()); + return (real && imag); +} + +// ================================================================================================= + +// Retrieves a list of example scalar values, used for the alpha and beta arguments for the various +// routines. This function is specialised for the different data-types. +template <> const std::vector<float> GetExampleScalars(const bool full_test) { + if (full_test) { return {0.0f, 1.0f, 3.14f}; } + else { return {3.14f}; } +} +template <> const std::vector<double> GetExampleScalars(const bool full_test) { + if (full_test) { return {0.0, 1.0, 3.14}; } + else { return {3.14}; } +} +template <> const std::vector<float2> GetExampleScalars(const bool full_test) { + if (full_test) { return {{0.0f, 0.0f}, {1.0f, 1.3f}, {2.42f, 3.14f}}; } + else { return {{2.42f, 3.14f}}; } +} +template <> const std::vector<double2> GetExampleScalars(const bool full_test) { + if (full_test) { return {{0.0, 0.0}, {1.0, 1.3}, {2.42, 3.14}}; } + else { return {{2.42, 3.14}}; } +} + +// ================================================================================================= + +// Returns false is this precision is not supported by the device +template <> bool PrecisionSupported<float>(const Device &) { return true; } +template <> bool PrecisionSupported<float2>(const Device &) { return true; } +template <> bool PrecisionSupported<double>(const Device &device) { + auto extensions = device.Extensions(); + return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; +} +template <> bool PrecisionSupported<double2>(const Device &device) { + auto extensions = device.Extensions(); + return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; +} + +// ================================================================================================= // Compiles the templated class -template class Tester<float>; -template class Tester<double>; -template class Tester<float2>; -template class Tester<double2>; +template class Tester<float, float>; +template class Tester<double, double>; +template class Tester<float2, float2>; +template class Tester<double2, double2>; +template class Tester<float2, float>; +template class Tester<double2, double>; // ================================================================================================= } // namespace clblast |