diff options
-rw-r--r-- | .gitignore | 4 | ||||
-rw-r--r-- | bottleneck/include/bottleneck_detail.hpp | 28 | ||||
-rw-r--r-- | bottleneck/tests/data/test_880_A | 4 | ||||
-rw-r--r-- | bottleneck/tests/data/test_880_B | 4 | ||||
-rw-r--r-- | bottleneck/tests/data/test_list.txt | 1 | ||||
-rw-r--r-- | bottleneck/tests/test_hera_bottleneck.cpp | 16 | ||||
-rw-r--r-- | matching/include/matching_distance.hpp | 1 | ||||
-rw-r--r-- | matching/include/persistence_module.h | 1 | ||||
-rw-r--r-- | matching/include/persistence_module.hpp | 2 | ||||
-rw-r--r-- | wasserstein/include/wasserstein.h | 70 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_1_A | 3 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_1_B | 4 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_2_A | 3 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_2_B | 5 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_3_A | 6 | ||||
-rw-r--r-- | wasserstein/tests/data/test_inf_3_B | 5 | ||||
-rw-r--r-- | wasserstein/tests/data/test_list.txt | 3 | ||||
-rw-r--r-- | wasserstein/tests/test_hera_wasserstein.cpp | 29 |
18 files changed, 139 insertions, 50 deletions
@@ -1,6 +1,6 @@ /*.cfg -/build -build/ +*build*/ +.idea/ *.gitattributes *.opensdf *.sdf diff --git a/bottleneck/include/bottleneck_detail.hpp b/bottleneck/include/bottleneck_detail.hpp index 8f51d07..f976909 100644 --- a/bottleneck/include/bottleneck_detail.hpp +++ b/bottleneck/include/bottleneck_detail.hpp @@ -51,7 +51,6 @@ namespace hera { void binarySearch(const Real epsilon, std::pair<Real, Real>& result, BoundMatchOracle <Real>& oracle, - const Real infinityCost, bool isResultInitializedCorrectly, const Real distProbeInit) { @@ -59,9 +58,6 @@ namespace hera { Real& distMin = result.first; Real& distMax = result.second; - distMin = std::max(distMin, infinityCost); - distMax = std::max(distMax, infinityCost); - Real distProbe; if (not isResultInitializedCorrectly) { @@ -87,13 +83,6 @@ namespace hera { // bounds are correct , perform binary search distProbe = (distMin + distMax) / 2.0; while ((distMax - distMin) / distMin >= epsilon) { - - if (distMax < infinityCost) { - distMin = infinityCost; - distMax = infinityCost; - break; - } - if (oracle.isMatchLess(distProbe)) { distMax = distProbe; } else { @@ -102,9 +91,6 @@ namespace hera { distProbe = (distMin + distMax) / 2.0; } - - distMin = std::max(distMin, infinityCost); - distMax = std::max(distMax, infinityCost); } // template<class Real> @@ -291,7 +277,7 @@ namespace hera { // get a 3-approximation of maximal distance between A and B // as a starting value for probe distance Real distProbe { getFurthestDistance3Approx<Real, DiagramPointSet<Real>>(A, B) }; - binarySearch(epsilon, result, oracle, infinity_cost, false, distProbe); + binarySearch(epsilon, result, oracle, false, distProbe); // to compute longest edge a perfect matching is needed if (compute_longest_edge and result.first > infinity_cost) { oracle.isMatchLess(result.second); @@ -448,8 +434,10 @@ namespace hera { Real bottleneckDistApprox(DiagramPointSet <Real>& A, DiagramPointSet <Real>& B, const Real epsilon, MatchingEdge <Real>& longest_edge, bool compute_longest_edge) { + // must compute here: infinity points will be erased in bottleneckDistApproxInterval + Real infCost = getInfinityCost(A, B).cost; auto interval = bottleneckDistApproxInterval<Real>(A, B, epsilon, longest_edge, compute_longest_edge); - return interval.second; + return std::max(infCost, interval.second); } @@ -518,6 +506,9 @@ namespace hera { using DgmPoint = DiagramPoint<Real>; constexpr Real epsilon = 0.001; + + Real infCost = getInfinityCost(A, B, true).cost; + auto interval = bottleneckDistApproxInterval(A, B, epsilon, longest_edge, true); // if the longest edge is on infinity, the answer is already exact // this will be detected here and all the code after if @@ -627,8 +618,9 @@ namespace hera { pw_dists.push_back(d); } - return bottleneckDistExactFromSortedPwDist(A, B, pw_dists, decPrecision, longest_edge, - compute_longest_edge); + Real exactFinite = bottleneckDistExactFromSortedPwDist(A, B, pw_dists, decPrecision, longest_edge, compute_longest_edge); + + return std::max(infCost, exactFinite); } } // end namespace bt diff --git a/bottleneck/tests/data/test_880_A b/bottleneck/tests/data/test_880_A new file mode 100644 index 0000000..5ad7454 --- /dev/null +++ b/bottleneck/tests/data/test_880_A @@ -0,0 +1,4 @@ +0.4217142857142855 inf +0.8868393268494345 1.8 +0.45290401176484574 1.8 +0.9154285714285715 1.8 diff --git a/bottleneck/tests/data/test_880_B b/bottleneck/tests/data/test_880_B new file mode 100644 index 0000000..22129a2 --- /dev/null +++ b/bottleneck/tests/data/test_880_B @@ -0,0 +1,4 @@ +0.8742857142857143 1.8 +0.46285714285714286 inf +0.4722695736819835 1.8 +0.8474648251910301 1.8 diff --git a/bottleneck/tests/data/test_list.txt b/bottleneck/tests/data/test_list.txt index 24b462a..91c03a4 100644 --- a/bottleneck/tests/data/test_list.txt +++ b/bottleneck/tests/data/test_list.txt @@ -790,3 +790,4 @@ test_876_A test_876_B 0.01 165 test_877_A test_877_B 0.01 190 test_878_A test_878_B 0.01 242 test_879_A test_879_B 0.01 187 +test_880_A test_880_B 0.0 0.0411428571428574 diff --git a/bottleneck/tests/test_hera_bottleneck.cpp b/bottleneck/tests/test_hera_bottleneck.cpp index f22e415..c74e800 100644 --- a/bottleneck/tests/test_hera_bottleneck.cpp +++ b/bottleneck/tests/test_hera_bottleneck.cpp @@ -159,6 +159,9 @@ TEST_CASE("infinity points", "bottleneckDistApprox") double d = hera::bottleneckDistApprox<>(diagram_A, diagram_B, delta); double corr_answer = 1.0; REQUIRE( fabs(d - corr_answer) <= delta * corr_answer); + + double exact_d = hera::bottleneckDistExact<>(diagram_A, diagram_B); + REQUIRE(exact_d == corr_answer); } SECTION("two points at infinity") { @@ -559,15 +562,23 @@ TEST_CASE("file cases", "bottleneck_dist") SECTION("from file:") { - for (const auto& ts : test_params) { + for (auto& ts : test_params) { bool read_file_A = hera::readDiagramPointSet(dir_prefix + ts.file_1, diagram_A); bool read_file_B = hera::readDiagramPointSet(dir_prefix + ts.file_2, diagram_B); REQUIRE(read_file_A); REQUIRE(read_file_B); - double hera_answer = hera::bottleneckDistApprox(diagram_A, diagram_B, ts.delta, longest_edge, true); + double hera_answer; + if (ts.delta > 0) + hera_answer = hera::bottleneckDistApprox(diagram_A, diagram_B, ts.delta, longest_edge, true); + else + hera_answer = hera::bottleneckDistExact(diagram_A, diagram_B); std::pair<int, int> hera_le { longest_edge.first.get_user_id(), longest_edge.second.get_user_id() }; + // cannot store exact answer in test_list.txt, but need to make sure that Exact is called + if (ts.delta == 0.0) + ts.delta = 0.00000001; + REQUIRE((hera_answer == ts.answer or fabs(hera_answer - ts.answer) <= ts.delta * hera_answer)); REQUIRE((ts.longest_edges.empty() or std::find(ts.longest_edges.begin(), ts.longest_edges.end(), hera_le) != ts.longest_edges.end())); @@ -604,7 +615,6 @@ TEST_CASE("file cases", "bottleneck_dist") // << hera_answer_exact << std::endl; // } REQUIRE( (not check_longest_edge_cost or fabs(hera_le_cost - hera_answer_exact) < 0.0001 * hera_answer_exact) ); - std::cout << ts << " PASSED " << std::endl; } } diff --git a/matching/include/matching_distance.hpp b/matching/include/matching_distance.hpp index f7f44a5..9beab1f 100644 --- a/matching/include/matching_distance.hpp +++ b/matching/include/matching_distance.hpp @@ -362,6 +362,7 @@ namespace md { // TODO: think about this - how to call Hera auto dgm_a = module_a_.weighted_slice_diagram(line); auto dgm_b = module_b_.weighted_slice_diagram(line); + R result; if (params_.hera_epsilon > static_cast<R>(0)) { result = hera::bottleneckDistApprox(dgm_a, dgm_b, params_.hera_epsilon) / ( params_.hera_epsilon + 1); diff --git a/matching/include/persistence_module.h b/matching/include/persistence_module.h index b68c21e..4df2148 100644 --- a/matching/include/persistence_module.h +++ b/matching/include/persistence_module.h @@ -11,6 +11,7 @@ #include "phat/boundary_matrix.h" #include "phat/compute_persistence_pairs.h" +#include "phat/algorithms/standard_reduction.h" #include "common_util.h" #include "dual_point.h" diff --git a/matching/include/persistence_module.hpp b/matching/include/persistence_module.hpp index 7479e02..022b43d 100644 --- a/matching/include/persistence_module.hpp +++ b/matching/include/persistence_module.hpp @@ -158,7 +158,7 @@ namespace md { get_slice_projection_matrix(slice, phat_matrix, gen_projections, rel_projections); phat::persistence_pairs phat_persistence_pairs; - phat::compute_persistence_pairs<phat::twist_reduction>(phat_persistence_pairs, phat_matrix); + phat::compute_persistence_pairs<phat::standard_reduction>(phat_persistence_pairs, phat_matrix, true); Diagram<Real> dgm; diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h index db6ce11..142fcbb 100644 --- a/wasserstein/include/wasserstein.h +++ b/wasserstein/include/wasserstein.h @@ -73,21 +73,19 @@ namespace ws template<class PairContainer> inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2) { - if (dgm1.size() != dgm2.size()) { - return false; - } - using Traits = typename hera::DiagramTraits<PairContainer>; using PointType = typename Traits::PointType; std::map<PointType, int> m1, m2; for(auto&& pair1 : dgm1) { - m1[pair1]++; + if (Traits::get_x(pair1) != Traits::get_y(pair1)) + m1[pair1]++; } for(auto&& pair2 : dgm2) { - m2[pair2]++; + if (Traits::get_x(pair2) != Traits::get_y(pair2)) + m2[pair2]++; } return m1 == m2; @@ -255,6 +253,9 @@ wasserstein_cost(const PairContainer& A, //using PointType = typename Traits::PointType; using RealType = typename Traits::RealType; + constexpr RealType plus_inf = std::numeric_limits<RealType>::infinity(); + constexpr RealType minus_inf = -std::numeric_limits<RealType>::infinity(); + if (hera::ws::are_equal(A, B)) { return 0.0; } @@ -270,19 +271,34 @@ wasserstein_cost(const PairContainer& A, // coordinates of points at infinity std::vector<RealType> x_plus_A, x_minus_A, y_plus_A, y_minus_A; std::vector<RealType> x_plus_B, x_minus_B, y_plus_B, y_minus_B; + // points with both coordinates infinite are treated as equal + int n_minus_inf_plus_inf_A = 0; + int n_plus_inf_minus_inf_A = 0; + int n_minus_inf_plus_inf_B = 0; + int n_plus_inf_minus_inf_B = 0; // loop over A, add projections of A-points to corresponding positions // in B-vector for(auto&& pair_A : A) { a_empty = false; RealType x = Traits::get_x(pair_A); RealType y = Traits::get_y(pair_A); - if ( x == std::numeric_limits<RealType>::infinity()) { + + // skip diagonal points, including (inf, inf), (-inf, -inf) + if (x == y) { + continue; + } + + if (x == plus_inf && y == minus_inf) { + n_plus_inf_minus_inf_A++; + } else if (x == minus_inf && y == plus_inf) { + n_minus_inf_plus_inf_A++; + } else if ( x == plus_inf) { y_plus_A.push_back(y); - } else if (x == -std::numeric_limits<RealType>::infinity()) { + } else if (x == minus_inf) { y_minus_A.push_back(y); - } else if (y == std::numeric_limits<RealType>::infinity()) { + } else if (y == plus_inf) { x_plus_A.push_back(x); - } else if (y == -std::numeric_limits<RealType>::infinity()) { + } else if (y == minus_inf) { x_minus_A.push_back(x); } else { dgm_A.emplace_back(x, y, DgmPoint::NORMAL); @@ -295,13 +311,22 @@ wasserstein_cost(const PairContainer& A, b_empty = false; RealType x = Traits::get_x(pair_B); RealType y = Traits::get_y(pair_B); - if (x == std::numeric_limits<RealType>::infinity()) { + + if (x == y) { + continue; + } + + if (x == plus_inf && y == minus_inf) { + n_plus_inf_minus_inf_B++; + } else if (x == minus_inf && y == plus_inf) { + n_minus_inf_plus_inf_B++; + } else if (x == plus_inf) { y_plus_B.push_back(y); - } else if (x == -std::numeric_limits<RealType>::infinity()) { + } else if (x == minus_inf) { y_minus_B.push_back(y); - } else if (y == std::numeric_limits<RealType>::infinity()) { + } else if (y == plus_inf) { x_plus_B.push_back(x); - } else if (y == -std::numeric_limits<RealType>::infinity()) { + } else if (y == minus_inf) { x_minus_B.push_back(x); } else { dgm_A.emplace_back(x, y, DgmPoint::DIAG); @@ -310,10 +335,16 @@ wasserstein_cost(const PairContainer& A, } } - RealType infinity_cost = ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power); + RealType infinity_cost = 0; + + if (n_plus_inf_minus_inf_A != n_plus_inf_minus_inf_B || n_minus_inf_plus_inf_A != n_minus_inf_plus_inf_B) + infinity_cost = plus_inf; + else { + infinity_cost += ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power); + } if (a_empty) return total_cost_B + infinity_cost; @@ -321,8 +352,7 @@ wasserstein_cost(const PairContainer& A, if (b_empty) return total_cost_A + infinity_cost; - - if (infinity_cost == std::numeric_limits<RealType>::infinity()) { + if (infinity_cost == plus_inf) { return infinity_cost; } else { return infinity_cost + wasserstein_cost_vec(dgm_A, dgm_B, params, _log_filename_prefix); diff --git a/wasserstein/tests/data/test_inf_1_A b/wasserstein/tests/data/test_inf_1_A new file mode 100644 index 0000000..c773f02 --- /dev/null +++ b/wasserstein/tests/data/test_inf_1_A @@ -0,0 +1,3 @@ +-inf inf +-inf inf +2 1 diff --git a/wasserstein/tests/data/test_inf_1_B b/wasserstein/tests/data/test_inf_1_B new file mode 100644 index 0000000..a55e496 --- /dev/null +++ b/wasserstein/tests/data/test_inf_1_B @@ -0,0 +1,4 @@ +-inf inf +-inf inf +inf inf +4 9 diff --git a/wasserstein/tests/data/test_inf_2_A b/wasserstein/tests/data/test_inf_2_A new file mode 100644 index 0000000..c773f02 --- /dev/null +++ b/wasserstein/tests/data/test_inf_2_A @@ -0,0 +1,3 @@ +-inf inf +-inf inf +2 1 diff --git a/wasserstein/tests/data/test_inf_2_B b/wasserstein/tests/data/test_inf_2_B new file mode 100644 index 0000000..6d7e751 --- /dev/null +++ b/wasserstein/tests/data/test_inf_2_B @@ -0,0 +1,5 @@ +-inf inf +-inf inf +inf -inf +inf inf +4 9 diff --git a/wasserstein/tests/data/test_inf_3_A b/wasserstein/tests/data/test_inf_3_A new file mode 100644 index 0000000..4f3fc2f --- /dev/null +++ b/wasserstein/tests/data/test_inf_3_A @@ -0,0 +1,6 @@ +-inf inf +-inf inf +inf -inf +-inf -inf +-inf -inf +2 1 diff --git a/wasserstein/tests/data/test_inf_3_B b/wasserstein/tests/data/test_inf_3_B new file mode 100644 index 0000000..6d7e751 --- /dev/null +++ b/wasserstein/tests/data/test_inf_3_B @@ -0,0 +1,5 @@ +-inf inf +-inf inf +inf -inf +inf inf +4 9 diff --git a/wasserstein/tests/data/test_list.txt b/wasserstein/tests/data/test_list.txt index 27340d8..b1ba6ed 100644 --- a/wasserstein/tests/data/test_list.txt +++ b/wasserstein/tests/data/test_list.txt @@ -19,3 +19,6 @@ test_100_A test_100_B 3.0 2.0 2.09695346034248 test_diag1_A test_diag1_B 1.0 -1.0 0.0 test_diag2_A test_diag2_B 1.0 -1.0 0.0 test_diag3_A test_diag3_B 1.0 -1.0 0.0 +test_inf_1_A test_inf_1_B 1.0 -1.0 3.0 +test_inf_2_A test_inf_2_B 1.0 -1.0 inf +test_inf_3_A test_inf_3_B 1.0 -1.0 3.0 diff --git a/wasserstein/tests/test_hera_wasserstein.cpp b/wasserstein/tests/test_hera_wasserstein.cpp index 0a80d2f..6f5de3b 100644 --- a/wasserstein/tests/test_hera_wasserstein.cpp +++ b/wasserstein/tests/test_hera_wasserstein.cpp @@ -111,6 +111,18 @@ TEST_CASE("simple cases", "wasserstein_dist") } + SECTION("trivial: two diagrams differing by diagonal point") { + + diagram_A.emplace_back(0.0, 1.0); + diagram_B.emplace_back(0.0, 0.0); + diagram_B.emplace_back(0.0, 1.0); + + double d1 = hera::wasserstein_cost<>(diagram_A, diagram_B, params); + double d2 = hera::wasserstein_cost<>(diagram_B, diagram_A, params); + REQUIRE( fabs(d2) <= 0.00000000001 ); + REQUIRE( fabs(d1) <= 0.00000000001 ); + } + } @@ -130,9 +142,14 @@ TEST_CASE("file cases", "wasserstein_dist") SECTION("from file:") { - const char* file_name = "../tests/data/test_list.txt"; + const char* file_name = "test_list.txt"; std::ifstream f; f.open(file_name); + if (!f.good()) { + std::cerr << "Must run from tests/data" << std::endl; + REQUIRE(false); + } + std::vector<TestFromFileCase> test_params; std::string s; while (std::getline(f, s)) { @@ -147,13 +164,13 @@ TEST_CASE("file cases", "wasserstein_dist") REQUIRE( read_file_A ); REQUIRE( read_file_B ); double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params); - REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer ); - std::cout << ts << " PASSED " << std::endl; + bool is_correct = (hera_answer == ts.answer) || (fabs(hera_answer - ts.answer) <= 0.01 * hera_answer); + REQUIRE(is_correct); } } SECTION("from DIPHA file:") { - const char* file_name = "../tests/data/test_list.txt"; + const char* file_name = "test_list.txt"; std::ifstream f; f.open(file_name); std::vector<TestFromFileCase> test_params; @@ -167,8 +184,8 @@ TEST_CASE("file cases", "wasserstein_dist") params.internal_p = ts.internal_p; bool read_file_A = hera::read_diagram_dipha<double, PairVector>(ts.file_1 + std::string(".pd.dipha"), 1, diagram_A); bool read_file_B = hera::read_diagram_dipha<double, PairVector>(ts.file_2 + std::string(".pd.dipha"), 1, diagram_B); - REQUIRE( read_file_A ); - REQUIRE( read_file_B ); + if (!read_file_A) + continue; double hera_answer = hera::wasserstein_dist(diagram_A, diagram_B, params); REQUIRE( fabs(hera_answer - ts.answer) <= 0.01 * hera_answer ); std::cout << ts << " PASSED " << std::endl; |