summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein/wasserstein.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-27 12:16:22 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-27 12:16:22 +0200
commitc1ab7c43d4797da93aa74ba823dd1a6b28fb2cfd (patch)
tree91997aceff0b54e6c1e36d33b4f3f55ce0a2571a /src/python/gudhi/wasserstein/wasserstein.py
parentf3307644119172ea99ded6da94e57869cf9f1981 (diff)
now consider (inf,inf) as belonging to the diagonal ; more tests
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 3abecfe6..5095e672 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -106,6 +106,8 @@ def _get_essential_parts(a):
.. note::
For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x.
Note also that points with (+inf, -inf) are not handled (points (x,y) in dgm satisfy by assumption (y >= x)).
+
+ Finally, we consider that points with coordinates (-inf,-inf) and (+inf, +inf) belong to the diagonal.
'''
if len(a):
first_coord_finite = np.isfinite(a[:,0])
@@ -118,6 +120,7 @@ def _get_essential_parts(a):
ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x)
ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf)
ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf)
+
ess_fourth_type = np.where(first_coord_infinite_negative & second_coord_infinite_negative)[0] # coord (-inf, -inf)
ess_fifth_type = np.where(first_coord_infinite_positive & second_coord_infinite_positive)[0] # coord (+inf, +inf)
return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type
@@ -162,7 +165,7 @@ def _handle_essential_parts(X, Y, order):
ess_parts_Y = _get_essential_parts(Y)
# Treats the case of infinite cost (cardinalities of essential parts differ).
- for u, v in zip(ess_parts_X, ess_parts_Y):
+ for u, v in list(zip(ess_parts_X, ess_parts_Y))[:3]: # ignore types 4 and 5 as they belong to the diagonal
if len(u) != len(v):
return np.inf, None
@@ -174,9 +177,14 @@ def _handle_essential_parts(X, Y, order):
c = c1 + c2
m = m1 + m2
- # Handle type >= 2 (both coordinates are infinite, so we essentially just align points)
- for u, v in zip(ess_parts_X[2:], ess_parts_Y[2:]):
- m += list(zip(u, v)) # cost is 0
+ # Handle type 3 (coordinates (-inf, +inf), so we just align points)
+ m += list(zip(ess_parts_X[2], ess_parts_Y[2]))
+
+ # Handle type 4 and 5, considered as belonging to the diagonal so matched to (-1) with cost 0.
+ for z in ess_parts_X[3:]:
+ m += [(u, -1) for u in z] # points in X are matched to -1
+ for z in ess_parts_Y[3:]:
+ m += [(-1, v) for v in z] # -1 is matched to points in Y
return c, np.array(m)
@@ -334,7 +342,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
return ep.concatenate(dists).norms.lp(order).raw
# We can also concatenate the 3 vectors to compute just one norm.
- # Comptuation of the otcost using the ot.emd2 library.
+ # Computation of the ot cost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
ot_cost = ot.emd2(a, b, M, numItermax=2000000)