summaryrefslogtreecommitdiff
path: root/wasserstein/include/wasserstein.h
diff options
context:
space:
mode:
Diffstat (limited to 'wasserstein/include/wasserstein.h')
-rw-r--r--wasserstein/include/wasserstein.h70
1 files changed, 50 insertions, 20 deletions
diff --git a/wasserstein/include/wasserstein.h b/wasserstein/include/wasserstein.h
index db6ce11..142fcbb 100644
--- a/wasserstein/include/wasserstein.h
+++ b/wasserstein/include/wasserstein.h
@@ -73,21 +73,19 @@ namespace ws
template<class PairContainer>
inline bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2)
{
- if (dgm1.size() != dgm2.size()) {
- return false;
- }
-
using Traits = typename hera::DiagramTraits<PairContainer>;
using PointType = typename Traits::PointType;
std::map<PointType, int> m1, m2;
for(auto&& pair1 : dgm1) {
- m1[pair1]++;
+ if (Traits::get_x(pair1) != Traits::get_y(pair1))
+ m1[pair1]++;
}
for(auto&& pair2 : dgm2) {
- m2[pair2]++;
+ if (Traits::get_x(pair2) != Traits::get_y(pair2))
+ m2[pair2]++;
}
return m1 == m2;
@@ -255,6 +253,9 @@ wasserstein_cost(const PairContainer& A,
//using PointType = typename Traits::PointType;
using RealType = typename Traits::RealType;
+ constexpr RealType plus_inf = std::numeric_limits<RealType>::infinity();
+ constexpr RealType minus_inf = -std::numeric_limits<RealType>::infinity();
+
if (hera::ws::are_equal(A, B)) {
return 0.0;
}
@@ -270,19 +271,34 @@ wasserstein_cost(const PairContainer& A,
// coordinates of points at infinity
std::vector<RealType> x_plus_A, x_minus_A, y_plus_A, y_minus_A;
std::vector<RealType> x_plus_B, x_minus_B, y_plus_B, y_minus_B;
+ // points with both coordinates infinite are treated as equal
+ int n_minus_inf_plus_inf_A = 0;
+ int n_plus_inf_minus_inf_A = 0;
+ int n_minus_inf_plus_inf_B = 0;
+ int n_plus_inf_minus_inf_B = 0;
// loop over A, add projections of A-points to corresponding positions
// in B-vector
for(auto&& pair_A : A) {
a_empty = false;
RealType x = Traits::get_x(pair_A);
RealType y = Traits::get_y(pair_A);
- if ( x == std::numeric_limits<RealType>::infinity()) {
+
+ // skip diagonal points, including (inf, inf), (-inf, -inf)
+ if (x == y) {
+ continue;
+ }
+
+ if (x == plus_inf && y == minus_inf) {
+ n_plus_inf_minus_inf_A++;
+ } else if (x == minus_inf && y == plus_inf) {
+ n_minus_inf_plus_inf_A++;
+ } else if ( x == plus_inf) {
y_plus_A.push_back(y);
- } else if (x == -std::numeric_limits<RealType>::infinity()) {
+ } else if (x == minus_inf) {
y_minus_A.push_back(y);
- } else if (y == std::numeric_limits<RealType>::infinity()) {
+ } else if (y == plus_inf) {
x_plus_A.push_back(x);
- } else if (y == -std::numeric_limits<RealType>::infinity()) {
+ } else if (y == minus_inf) {
x_minus_A.push_back(x);
} else {
dgm_A.emplace_back(x, y, DgmPoint::NORMAL);
@@ -295,13 +311,22 @@ wasserstein_cost(const PairContainer& A,
b_empty = false;
RealType x = Traits::get_x(pair_B);
RealType y = Traits::get_y(pair_B);
- if (x == std::numeric_limits<RealType>::infinity()) {
+
+ if (x == y) {
+ continue;
+ }
+
+ if (x == plus_inf && y == minus_inf) {
+ n_plus_inf_minus_inf_B++;
+ } else if (x == minus_inf && y == plus_inf) {
+ n_minus_inf_plus_inf_B++;
+ } else if (x == plus_inf) {
y_plus_B.push_back(y);
- } else if (x == -std::numeric_limits<RealType>::infinity()) {
+ } else if (x == minus_inf) {
y_minus_B.push_back(y);
- } else if (y == std::numeric_limits<RealType>::infinity()) {
+ } else if (y == plus_inf) {
x_plus_B.push_back(x);
- } else if (y == -std::numeric_limits<RealType>::infinity()) {
+ } else if (y == minus_inf) {
x_minus_B.push_back(x);
} else {
dgm_A.emplace_back(x, y, DgmPoint::DIAG);
@@ -310,10 +335,16 @@ wasserstein_cost(const PairContainer& A,
}
}
- RealType infinity_cost = ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power);
- infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power);
+ RealType infinity_cost = 0;
+
+ if (n_plus_inf_minus_inf_A != n_plus_inf_minus_inf_B || n_minus_inf_plus_inf_A != n_minus_inf_plus_inf_B)
+ infinity_cost = plus_inf;
+ else {
+ infinity_cost += ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power);
+ infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power);
+ }
if (a_empty)
return total_cost_B + infinity_cost;
@@ -321,8 +352,7 @@ wasserstein_cost(const PairContainer& A,
if (b_empty)
return total_cost_A + infinity_cost;
-
- if (infinity_cost == std::numeric_limits<RealType>::infinity()) {
+ if (infinity_cost == plus_inf) {
return infinity_cost;
} else {
return infinity_cost + wasserstein_cost_vec(dgm_A, dgm_B, params, _log_filename_prefix);