diff options
author | tlacombe <lacombe1993@gmail.com> | 2021-04-27 12:16:22 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2021-04-27 12:16:22 +0200 |
commit | c1ab7c43d4797da93aa74ba823dd1a6b28fb2cfd (patch) | |
tree | 91997aceff0b54e6c1e36d33b4f3f55ce0a2571a /src/python/test/test_wasserstein_distance.py | |
parent | f3307644119172ea99ded6da94e57869cf9f1981 (diff) |
now consider (inf,inf) as belonging to the diagonal ; more tests
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 121ba065..3a004d77 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -10,6 +10,7 @@ """ from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts +from gudhi.wasserstein.wasserstein import _warn_infty from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -50,16 +51,17 @@ def test_handle_essential_parts(): [-np.inf, np.inf], [-np.inf, np.inf]]) diag3 = np.array([[0, 2], [3, 5], - [2, np.inf], [4, np.inf], + [2, np.inf], [4, np.inf], [6, np.inf], [-np.inf, 8], [-np.inf, 11], - [-np.inf, -np.inf], [-np.inf, -np.inf], + [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) c, m = _handle_essential_parts(diag1, diag2, order=1) assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3) # Similarly, the matching only corresponds to essential parts. - assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]]) + # Note that (-inf,-inf) and (+inf,+inf) coordinates are matched to the diagonal. + assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, -1], [7, -1], [-1, 6], [-1, 7]]) c, m = _handle_essential_parts(diag1, diag3, order=1) assert c == np.inf @@ -87,6 +89,13 @@ def test_get_essential_parts(): assert np.array_equal(res2[4], [] ) +def test_warn_infty(): + assert _warn_infty(matching=False)==np.inf + c, m = _warn_infty(matching=True) + assert (c == np.inf) + assert (m is None) + + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -143,11 +152,29 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat if test_matching and test_infinity: diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]]) + diag8 = np.array([[0,1], [0, np.inf], [-np.inf, -np.inf], [np.inf, np.inf]]) + diag9 = np.array([[-np.inf, -np.inf], [np.inf, np.inf]]) + diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]]) match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1] assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1] assert (match is None) + cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(emptydiag, diag7, matching=True, internal_p=2.42, order=2.) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(diag8, diag9, matching=True, internal_p=2., order=2.) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.) + assert (cost == 1) + assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway. + cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.) + assert (cost == 0.) + assert (match == [[0, -1], [1, -1]]) def hera_wrap(**extra): @@ -155,15 +182,18 @@ def hera_wrap(**extra): return hera(*kargs,**kwargs,**extra) return fun + def pot_wrap(**extra): def fun(*kargs,**kwargs): return pot(*kargs,**kwargs,**extra) return fun + def test_wasserstein_distance_pot(): _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False) + def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) |