summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-27 12:16:22 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-27 12:16:22 +0200
commitc1ab7c43d4797da93aa74ba823dd1a6b28fb2cfd (patch)
tree91997aceff0b54e6c1e36d33b4f3f55ce0a2571a /src/python/test/test_wasserstein_distance.py
parentf3307644119172ea99ded6da94e57869cf9f1981 (diff)
now consider (inf,inf) as belonging to the diagonal ; more tests
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py36
1 files changed, 33 insertions, 3 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 121ba065..3a004d77 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -10,6 +10,7 @@
"""
from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts
+from gudhi.wasserstein.wasserstein import _warn_infty
from gudhi.wasserstein import wasserstein_distance as pot
from gudhi.hera import wasserstein_distance as hera
import numpy as np
@@ -50,16 +51,17 @@ def test_handle_essential_parts():
[-np.inf, np.inf], [-np.inf, np.inf]])
diag3 = np.array([[0, 2], [3, 5],
- [2, np.inf], [4, np.inf],
+ [2, np.inf], [4, np.inf], [6, np.inf],
[-np.inf, 8], [-np.inf, 11],
- [-np.inf, -np.inf], [-np.inf, -np.inf],
+ [-np.inf, -np.inf],
[np.inf, np.inf],
[-np.inf, np.inf], [-np.inf, np.inf]])
c, m = _handle_essential_parts(diag1, diag2, order=1)
assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3)
# Similarly, the matching only corresponds to essential parts.
- assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]])
+ # Note that (-inf,-inf) and (+inf,+inf) coordinates are matched to the diagonal.
+ assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, -1], [7, -1], [-1, 6], [-1, 7]])
c, m = _handle_essential_parts(diag1, diag3, order=1)
assert c == np.inf
@@ -87,6 +89,13 @@ def test_get_essential_parts():
assert np.array_equal(res2[4], [] )
+def test_warn_infty():
+ assert _warn_infty(matching=False)==np.inf
+ c, m = _warn_infty(matching=True)
+ assert (c == np.inf)
+ assert (m is None)
+
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
@@ -143,11 +152,29 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
if test_matching and test_infinity:
diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])
+ diag8 = np.array([[0,1], [0, np.inf], [-np.inf, -np.inf], [np.inf, np.inf]])
+ diag9 = np.array([[-np.inf, -np.inf], [np.inf, np.inf]])
+ diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]])
match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1]
assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1]
assert (match is None)
+ cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(emptydiag, diag7, matching=True, internal_p=2.42, order=2.)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(diag8, diag9, matching=True, internal_p=2., order=2.)
+ assert (cost == np.inf)
+ assert (match is None)
+ cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
+ assert (cost == 1)
+ assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
+ cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
+ assert (cost == 0.)
+ assert (match == [[0, -1], [1, -1]])
def hera_wrap(**extra):
@@ -155,15 +182,18 @@ def hera_wrap(**extra):
return hera(*kargs,**kwargs,**extra)
return fun
+
def pot_wrap(**extra):
def fun(*kargs,**kwargs):
return pot(*kargs,**kwargs,**extra)
return fun
+
def test_wasserstein_distance_pot():
_basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args
_basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False)
+
def test_wasserstein_distance_hera():
_basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
_basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)