summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-27 12:16:22 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-27 12:16:22 +0200
commitc1ab7c43d4797da93aa74ba823dd1a6b28fb2cfd (patch)
tree91997aceff0b54e6c1e36d33b4f3f55ce0a2571a
parentf3307644119172ea99ded6da94e57869cf9f1981 (diff)
now consider (inf,inf) as belonging to the diagonal ; more tests
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py18
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py36
2 files changed, 46 insertions, 8 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 3abecfe6..5095e672 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -106,6 +106,8 @@ def _get_essential_parts(a):
.. note::
For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x.
Note also that points with (+inf, -inf) are not handled (points (x,y) in dgm satisfy by assumption (y >= x)).
+
+ Finally, we consider that points with coordinates (-inf,-inf) and (+inf, +inf) belong to the diagonal.
'''
if len(a):
first_coord_finite = np.isfinite(a[:,0])
@@ -118,6 +120,7 @@ def _get_essential_parts(a):
ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x)
ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf)
ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf)
+
ess_fourth_type = np.where(first_coord_infinite_negative & second_coord_infinite_negative)[0] # coord (-inf, -inf)
ess_fifth_type = np.where(first_coord_infinite_positive & second_coord_infinite_positive)[0] # coord (+inf, +inf)
return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type
@@ -162,7 +165,7 @@ def _handle_essential_parts(X, Y, order):
ess_parts_Y = _get_essential_parts(Y)
# Treats the case of infinite cost (cardinalities of essential parts differ).
- for u, v in zip(ess_parts_X, ess_parts_Y):
+ for u, v in list(zip(ess_parts_X, ess_parts_Y))[:3]: # ignore types 4 and 5 as they belong to the diagonal
if len(u) != len(v):
return np.inf, None
@@ -174,9 +177,14 @@ def _handle_essential_parts(X, Y, order):
c = c1 + c2
m = m1 + m2
- # Handle type >= 2 (both coordinates are infinite, so we essentially just align points)
- for u, v in zip(ess_parts_X[2:], ess_parts_Y[2:]):
- m += list(zip(u, v)) # cost is 0
+ # Handle type3 (coordinates (-inf,+inf), so we just align points)
+ m += list(zip(ess_parts_X[2], ess_parts_Y[2]))
+
+ # Handle type 4 and 5, considered as belonging to the diagonal so matched to (-1) with cost 0.
+ for z in ess_parts_X[3:]:
+ m += [(u, -1) for u in z] # points in X are matched to -1
+ for z in ess_parts_Y[3:]:
+ m += [(-1, v) for v in z] # -1 is match to points in Y
return c, np.array(m)
@@ -334,7 +342,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
return ep.concatenate(dists).norms.lp(order).raw
# We can also concatenate the 3 vectors to compute just one norm.
- # Comptuation of the otcost using the ot.emd2 library.
+ # Comptuation of the ot cost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
ot_cost = ot.emd2(a, b, M, numItermax=2000000)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 121ba065..3a004d77 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -10,6 +10,7 @@
"""
from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts
+from gudhi.wasserstein.wasserstein import _warn_infty
from gudhi.wasserstein import wasserstein_distance as pot
from gudhi.hera import wasserstein_distance as hera
import numpy as np
@@ -50,16 +51,17 @@ def test_handle_essential_parts():
[-np.inf, np.inf], [-np.inf, np.inf]])
diag3 = np.array([[0, 2], [3, 5],
- [2, np.inf], [4, np.inf],
+ [2, np.inf], [4, np.inf], [6, np.inf],
[-np.inf, 8], [-np.inf, 11],
- [-np.inf, -np.inf], [-np.inf, -np.inf],
+ [-np.inf, -np.inf],
[np.inf, np.inf],
[-np.inf, np.inf], [-np.inf, np.inf]])
c, m = _handle_essential_parts(diag1, diag2, order=1)
assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3)
# Similarly, the matching only corresponds to essential parts.
- assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]])
+ # Note that (-inf,-inf) and (+inf,+inf) coordinates are matched to the diagonal.
+ assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, -1], [7, -1], [-1, 6], [-1, 7]])
c, m = _handle_essential_parts(diag1, diag3, order=1)
assert c == np.inf
@@ -87,6 +89,13 @@ def test_get_essential_parts():
assert np.array_equal(res2[4], [] )
+def test_warn_infty():
+ assert _warn_infty(matching=False)==np.inf
+ c, m = _warn_infty(matching=True)
+ assert (c == np.inf)
+ assert (m is None)
+
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
@@ -143,11 +152,29 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
if test_matching and test_infinity:
diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])
+ diag8 = np.array([[0,1], [0, np.inf], [-np.inf, -np.inf], [np.inf, np.inf]])
+ diag9 = np.array([[-np.inf, -np.inf], [np.inf, np.inf]])
+ diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]])
match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1]
assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1]
assert (match is None)
+ cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(emptydiag, diag7, matching=True, internal_p=2.42, order=2.)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(diag8, diag9, matching=True, internal_p=2., order=2.)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
+ assert (cost == 1)
+ assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
+ cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
+ assert (cost == 0.)
+ assert (match == [[0, -1], [1, -1]])
def hera_wrap(**extra):
@@ -155,15 +182,18 @@ def hera_wrap(**extra):
return hera(*kargs,**kwargs,**extra)
return fun
+
def pot_wrap(**extra):
def fun(*kargs,**kwargs):
return pot(*kargs,**kwargs,**extra)
return fun
+
def test_wasserstein_distance_pot():
_basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args
_basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False)
+
def test_wasserstein_distance_hera():
_basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
_basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)