From 58642131623733ed7360fa146d106cff7f3a057c Mon Sep 17 00:00:00 2001 From: Arnur Nigmetov Date: Fri, 1 Jun 2018 16:01:00 +0200 Subject: First trivial test for point cloud version --- geom_matching/wasserstein/CMakeLists.txt | 2 +- .../wasserstein/include/auction_runner_gs.hpp | 1 + .../wasserstein/include/auction_runner_jac.hpp | 1 + geom_matching/wasserstein/include/hera_infinity.h | 2 +- .../wasserstein/tests/test_hera_wasserstein.cpp | 53 +--------- .../tests/test_hera_wasserstein_pure_geom.cpp | 111 +++++++++++++++++++++ geom_matching/wasserstein/tests/tests_reader.h | 67 +++++++++++++ 7 files changed, 184 insertions(+), 53 deletions(-) create mode 100644 geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp create mode 100644 geom_matching/wasserstein/tests/tests_reader.h diff --git a/geom_matching/wasserstein/CMakeLists.txt b/geom_matching/wasserstein/CMakeLists.txt index c6fba2c..dea4550 100644 --- a/geom_matching/wasserstein/CMakeLists.txt +++ b/geom_matching/wasserstein/CMakeLists.txt @@ -57,6 +57,6 @@ add_executable(wasserstein_dist_point_cloud ${CMAKE_CURRENT_SOURCE_DIR}/example/ target_link_libraries(wasserstein_dist_point_cloud PUBLIC ${libraries}) # Tests -add_executable(wasserstein_test ${CMAKE_CURRENT_SOURCE_DIR}/tests/tests_main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp include/hera_infinity.h) +add_executable(wasserstein_test ${CMAKE_CURRENT_SOURCE_DIR}/tests/tests_main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein_pure_geom.cpp include/hera_infinity.h tests/tests_reader.h) #add_executable(wasserstein_test EXCLUDE_FROM_ALL ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_hera_wasserstein.cpp) target_link_libraries(wasserstein_test PUBLIC ${libraries}) diff --git a/geom_matching/wasserstein/include/auction_runner_gs.hpp b/geom_matching/wasserstein/include/auction_runner_gs.hpp index 960c707..141cb2c 100644 --- a/geom_matching/wasserstein/include/auction_runner_gs.hpp +++ b/geom_matching/wasserstein/include/auction_runner_gs.hpp @@ -287,6 +287,7 @@ void AuctionRunnerGS::run_auction() if (num_bidders == 1) { assign_item_to_bidder(0, 0); wasserstein_cost = get_item_bidder_cost(0,0); + is_distance_computed = true; return; } diff --git a/geom_matching/wasserstein/include/auction_runner_jac.hpp b/geom_matching/wasserstein/include/auction_runner_jac.hpp index c519de1..e623f4a 100644 --- a/geom_matching/wasserstein/include/auction_runner_jac.hpp +++ b/geom_matching/wasserstein/include/auction_runner_jac.hpp @@ -559,6 +559,7 @@ namespace ws { if (num_bidders == 1) { assign_item_to_bidder(0, 0); wasserstein_cost = get_item_bidder_cost(0,0); + is_distance_computed = true; return; } R init_eps = (initial_epsilon > 0.0) ? initial_epsilon : oracle.max_val_ / 4.0; diff --git a/geom_matching/wasserstein/include/hera_infinity.h b/geom_matching/wasserstein/include/hera_infinity.h index 5a446e7..8d86dbb 100644 --- a/geom_matching/wasserstein/include/hera_infinity.h +++ b/geom_matching/wasserstein/include/hera_infinity.h @@ -13,7 +13,7 @@ namespace hera { }; template - inline Real get_infinity() + inline constexpr Real get_infinity() { return Real(-1); } diff --git a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp index 3d5db5f..0a80d2f 100644 --- a/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp +++ b/geom_matching/wasserstein/tests/test_hera_wasserstein.cpp @@ -8,61 +8,12 @@ #undef LOG_AUCTION #include "wasserstein.h" +#include "tests_reader.h" +using namespace hera_test; using PairVector = std::vector>; -std::vector split_on_delim(const std::string& s, char delim) -{ - std::stringstream ss(s); - std::string token; - std::vector tokens; - while(std::getline(ss, token, delim)) { - tokens.push_back(token); - } - return tokens; -} - - -// single row in a file with test cases -struct TestFromFileCase { - - std::string file_1; - std::string file_2; - double q; - double internal_p; - double answer; - - TestFromFileCase(std::string s) - { - auto tokens = split_on_delim(s, ' '); - assert(tokens.size() == 5); - - file_1 = tokens.at(0); - file_2 = tokens.at(1); - q = std::stod(tokens.at(2)); - internal_p = std::stod(tokens.at(3)); - answer = std::stod(tokens.at(4)); - - if ( q < 1.0 or std::isinf(q) or - (internal_p != hera::get_infinity() and internal_p < 1.0)) { - throw std::runtime_error("Bad line in test_list.txt"); - } - } -}; - -std::ostream& operator<<(std::ostream& out, const TestFromFileCase& s) -{ - out << "[" << s.file_1 << ", " << s.file_2 << ", q = " << s.q << ", norm = "; - if (s.internal_p != hera::get_infinity()) { - out << s.internal_p; - } else { - out << "infinity"; - } - out << ", answer = " << s.answer << "]"; - return out; -} - TEST_CASE("simple cases", "wasserstein_dist") { diff --git a/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp b/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp new file mode 100644 index 0000000..9603ceb --- /dev/null +++ b/geom_matching/wasserstein/tests/test_hera_wasserstein_pure_geom.cpp @@ -0,0 +1,111 @@ +#include "catch/catch.hpp" + +#include +#include + + +#undef LOG_AUCTION + +#include "wasserstein_pure_geom.hpp" +#include "tests_reader.h" + +using namespace hera_test; + +TEST_CASE("simple point clouds", "wasserstein_dist_pure_geom") +{ +// int n_points = 3; +// int dim = 3; +// using Traits = hera::ws::dnn::DynamicPointTraits; +// hera::ws::dnn::DynamicPointTraits traits(dim); +// hera::ws::dnn::DynamicPointVector dgm_a = traits.container(n_points); +// hera::ws::dnn::DynamicPointVector dgm_b = traits.container(n_points); +// +// dgm_a[0][0] = 0.0; +// dgm_a[0][1] = 0.0; +// dgm_a[0][2] = 0.0; +// +// dgm_a[1][0] = 1.0; +// dgm_a[1][1] = 0.0; +// dgm_a[1][2] = 0.0; +// +// dgm_a[2][0] = 0.0; +// dgm_a[2][1] = 1.0; +// dgm_a[2][2] = 1.0; +// +// dgm_b[0][0] = 0.0; +// dgm_b[0][1] = 0.1; +// dgm_b[0][2] = 0.1; +// +// dgm_b[1][0] = 1.1; +// dgm_b[1][1] = 0.0; +// dgm_b[1][2] = 0.0; +// +// dgm_b[2][0] = 0.0; +// dgm_b[2][1] = 1.1; +// dgm_b[2][2] = 0.9; + + const int dim = 3; + using Traits = hera::ws::dnn::DynamicPointTraits; + hera::ws::dnn::DynamicPointTraits traits(dim); + hera::AuctionParams params; + params.dim = dim; + params.wasserstein_power = 1.0; + params.delta = 0.01; + params.internal_p = hera::get_infinity(); + params.initial_epsilon = 0.0; + params.epsilon_common_ratio = 0.0; + params.max_num_phases = 30; + params.gamma_threshold = 0.0; + params.max_bids_per_round = 0; // use Jacobi + + + SECTION("trivial: two single-point diagrams-1") { + + int n_points = 1; + hera::ws::dnn::DynamicPointVector dgm_a = traits.container(n_points); + hera::ws::dnn::DynamicPointVector dgm_b = traits.container(n_points); + + dgm_a[0][0] = 0.0; + dgm_a[0][1] = 0.0; + dgm_a[0][2] = 0.0; + + dgm_b[0][0] = 1.0; + dgm_b[0][1] = 1.0; + dgm_b[0][2] = 1.0; + + std::vector max_bids { 1, 10, 0 }; + std::vector internal_ps{ 1, 2, static_cast(hera::get_infinity()) }; + std::vector wasserstein_powers { 1, 2, 3 }; + + for(auto internal_p : internal_ps) { + // there is only one point, so the answer does not depend wasserstein power + double correct_answer; + switch (internal_p) { + case 1 : + correct_answer = 3.0; + break; + case 2 : + correct_answer = sqrt(3.0); + break; + case static_cast(hera::get_infinity()) : + correct_answer = 1.0; + break; + default : + throw std::runtime_error("Correct answer not specified in test case"); + } + + for (auto max_bid : max_bids) { + for (auto wasserstein_power : wasserstein_powers) { + params.max_bids_per_round = max_bid; + params.internal_p = internal_p; + params.wasserstein_power = wasserstein_power; + double d1 = hera::ws::wasserstein_dist(dgm_a, dgm_b, params); + double d2 = hera::ws::wasserstein_dist(dgm_b, dgm_a, params); + REQUIRE(fabs(d1 - d2) <= 0.00000000001); + REQUIRE(fabs(d1 - correct_answer) <= 0.00000000001); + } + } + } + } +} + diff --git a/geom_matching/wasserstein/tests/tests_reader.h b/geom_matching/wasserstein/tests/tests_reader.h new file mode 100644 index 0000000..f2d5735 --- /dev/null +++ b/geom_matching/wasserstein/tests/tests_reader.h @@ -0,0 +1,67 @@ +#ifndef WASSERSTEIN_TESTS_READER_H +#define WASSERSTEIN_TESTS_READER_H + +#include +#include +#include +#include +#include +#include +#include + +#include "hera_infinity.h" + +namespace hera_test { + inline std::vector split_on_delim(const std::string& s, char delim) + { + std::stringstream ss(s); + std::string token; + std::vector tokens; + while (std::getline(ss, token, delim)) { + tokens.push_back(token); + } + return tokens; + } + + + // single row in a file with test cases + struct TestFromFileCase + { + + std::string file_1; + std::string file_2; + double q; + double internal_p; + double answer; + + TestFromFileCase(std::string s) + { + auto tokens = split_on_delim(s, ' '); + assert(tokens.size() == 5); + + file_1 = tokens.at(0); + file_2 = tokens.at(1); + q = std::stod(tokens.at(2)); + internal_p = std::stod(tokens.at(3)); + answer = std::stod(tokens.at(4)); + + if (q < 1.0 or std::isinf(q) or + (internal_p != hera::get_infinity() and internal_p < 1.0)) { + throw std::runtime_error("Bad line in test_list.txt"); + } + } + }; + + inline std::ostream& operator<<(std::ostream& out, const TestFromFileCase& s) + { + out << "[" << s.file_1 << ", " << s.file_2 << ", q = " << s.q << ", norm = "; + if (s.internal_p != hera::get_infinity()) { + out << s.internal_p; + } else { + out << "infinity"; + } + out << ", answer = " << s.answer << "]"; + return out; + } +} // namespace hera_test +#endif //WASSERSTEIN_TESTS_READER_H -- cgit v1.2.3