diff options
author | Gard Spreemann <gspr@nonempty.org> | 2021-08-14 18:32:59 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2021-08-14 18:32:59 +0200 |
commit | 66702d9cf122703964dbe22319ae8d97424d496f (patch) | |
tree | 08681d5c5b5878ed4283d5fba2cbb8f4612dbf7c /wasserstein/include | |
parent | 069338dfb03b4d04c1410b3e24b762b18db5c233 (diff) | |
parent | 2ed9afc052bee7956f6abb195947de1f80cb9d91 (diff) |
Merge branch 'upstream/latest' into dfsg/latest
Diffstat (limited to 'wasserstein/include')
-rw-r--r-- | wasserstein/include/wasserstein.h | 70 |
1 files changed, 50 insertions, 20 deletions
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h index db6ce11..142fcbb 100644 --- a/wasserstein/include/wasserstein.h +++ b/wasserstein/include/wasserstein.h @@ -73,21 +73,19 @@ namespace ws template<class PairContainer> inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2) { - if (dgm1.size() != dgm2.size()) { - return false; - } - using Traits = typename hera::DiagramTraits<PairContainer>; using PointType = typename Traits::PointType; std::map<PointType, int> m1, m2; for(auto&& pair1 : dgm1) { - m1[pair1]++; + if (Traits::get_x(pair1) != Traits::get_y(pair1)) + m1[pair1]++; } for(auto&& pair2 : dgm2) { - m2[pair2]++; + if (Traits::get_x(pair2) != Traits::get_y(pair2)) + m2[pair2]++; } return m1 == m2; @@ -255,6 +253,9 @@ wasserstein_cost(const PairContainer& A, //using PointType = typename Traits::PointType; using RealType = typename Traits::RealType; + constexpr RealType plus_inf = std::numeric_limits<RealType>::infinity(); + constexpr RealType minus_inf = -std::numeric_limits<RealType>::infinity(); + if (hera::ws::are_equal(A, B)) { return 0.0; } @@ -270,19 +271,34 @@ wasserstein_cost(const PairContainer& A, // coordinates of points at infinity std::vector<RealType> x_plus_A, x_minus_A, y_plus_A, y_minus_A; std::vector<RealType> x_plus_B, x_minus_B, y_plus_B, y_minus_B; + // points with both coordinates infinite are treated as equal + int n_minus_inf_plus_inf_A = 0; + int n_plus_inf_minus_inf_A = 0; + int n_minus_inf_plus_inf_B = 0; + int n_plus_inf_minus_inf_B = 0; // loop over A, add projections of A-points to corresponding positions // in B-vector for(auto&& pair_A : A) { a_empty = false; RealType x = Traits::get_x(pair_A); RealType y = Traits::get_y(pair_A); - if ( x == std::numeric_limits<RealType>::infinity()) { + + // skip diagonal points, including (inf, inf), (-inf, -inf) + if (x == y) { + continue; + } + + if (x == plus_inf && y == minus_inf) { + n_plus_inf_minus_inf_A++; + } else if (x == minus_inf && y == plus_inf) { + n_minus_inf_plus_inf_A++; + } else if ( x == plus_inf) { y_plus_A.push_back(y); - } else if (x == -std::numeric_limits<RealType>::infinity()) { + } else if (x == minus_inf) { y_minus_A.push_back(y); - } else if (y == std::numeric_limits<RealType>::infinity()) { + } else if (y == plus_inf) { x_plus_A.push_back(x); - } else if (y == -std::numeric_limits<RealType>::infinity()) { + } else if (y == minus_inf) { x_minus_A.push_back(x); } else { dgm_A.emplace_back(x, y, DgmPoint::NORMAL); @@ -295,13 +311,22 @@ wasserstein_cost(const PairContainer& A, b_empty = false; RealType x = Traits::get_x(pair_B); RealType y = Traits::get_y(pair_B); - if (x == std::numeric_limits<RealType>::infinity()) { + + if (x == y) { + continue; + } + + if (x == plus_inf && y == minus_inf) { + n_plus_inf_minus_inf_B++; + } else if (x == minus_inf && y == plus_inf) { + n_minus_inf_plus_inf_B++; + } else if (x == plus_inf) { y_plus_B.push_back(y); - } else if (x == -std::numeric_limits<RealType>::infinity()) { + } else if (x == minus_inf) { y_minus_B.push_back(y); - } else if (y == std::numeric_limits<RealType>::infinity()) { + } else if (y == plus_inf) { x_plus_B.push_back(x); - } else if (y == -std::numeric_limits<RealType>::infinity()) { + } else if (y == minus_inf) { x_minus_B.push_back(x); } else { dgm_A.emplace_back(x, y, DgmPoint::DIAG); @@ -310,10 +335,16 @@ wasserstein_cost(const PairContainer& A, } } - RealType infinity_cost = ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power); - infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power); + RealType infinity_cost = 0; + + if (n_plus_inf_minus_inf_A != n_plus_inf_minus_inf_B || n_minus_inf_plus_inf_A != n_minus_inf_plus_inf_B) + infinity_cost = plus_inf; + else { + infinity_cost += ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power); + } if (a_empty) return total_cost_B + infinity_cost; @@ -321,8 +352,7 @@ wasserstein_cost(const PairContainer& A, if (b_empty) return total_cost_A + infinity_cost; - - if (infinity_cost == std::numeric_limits<RealType>::infinity()) { + if (infinity_cost == plus_inf) { return infinity_cost; } else { return infinity_cost + wasserstein_cost_vec(dgm_A, dgm_B, params, _log_filename_prefix); |