diff options
author | tlacombe <lacombe1993@gmail.com> | 2021-04-27 12:16:22 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2021-04-27 12:16:22 +0200 |
commit | c1ab7c43d4797da93aa74ba823dd1a6b28fb2cfd (patch) | |
tree | 91997aceff0b54e6c1e36d33b4f3f55ce0a2571a /src | |
parent | f3307644119172ea99ded6da94e57869cf9f1981 (diff) |
Now consider (-inf,-inf) and (+inf,+inf) as belonging to the diagonal; more tests
Diffstat (limited to 'src')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 18 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 36 |
2 files changed, 46 insertions, 8 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 3abecfe6..5095e672 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -106,6 +106,8 @@ def _get_essential_parts(a): .. note:: For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x. Note also that points with (+inf, -inf) are not handled (points (x,y) in dgm satisfy by assumption (y >= x)). + + Finally, we consider that points with coordinates (-inf,-inf) and (+inf, +inf) belong to the diagonal. ''' if len(a): first_coord_finite = np.isfinite(a[:,0]) @@ -118,6 +120,7 @@ def _get_essential_parts(a): ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x) ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf) ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf) + ess_fourth_type = np.where(first_coord_infinite_negative & second_coord_infinite_negative)[0] # coord (-inf, -inf) ess_fifth_type = np.where(first_coord_infinite_positive & second_coord_infinite_positive)[0] # coord (+inf, +inf) return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type @@ -162,7 +165,7 @@ def _handle_essential_parts(X, Y, order): ess_parts_Y = _get_essential_parts(Y) # Treats the case of infinite cost (cardinalities of essential parts differ). 
- for u, v in zip(ess_parts_X, ess_parts_Y): + for u, v in list(zip(ess_parts_X, ess_parts_Y))[:3]: # ignore types 4 and 5 as they belong to the diagonal if len(u) != len(v): return np.inf, None @@ -174,9 +177,14 @@ def _handle_essential_parts(X, Y, order): c = c1 + c2 m = m1 + m2 - # Handle type >= 2 (both coordinates are infinite, so we essentially just align points) - for u, v in zip(ess_parts_X[2:], ess_parts_Y[2:]): - m += list(zip(u, v)) # cost is 0 + # Handle type3 (coordinates (-inf,+inf), so we just align points) + m += list(zip(ess_parts_X[2], ess_parts_Y[2])) + + # Handle type 4 and 5, considered as belonging to the diagonal so matched to (-1) with cost 0. + for z in ess_parts_X[3:]: + m += [(u, -1) for u in z] # points in X are matched to -1 + for z in ess_parts_Y[3:]: + m += [(-1, v) for v in z] # -1 is match to points in Y return c, np.array(m) @@ -334,7 +342,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab return ep.concatenate(dists).norms.lp(order).raw # We can also concatenate the 3 vectors to compute just one norm. - # Comptuation of the otcost using the ot.emd2 library. + # Comptuation of the ot cost using the ot.emd2 library. # Note: it is the Wasserstein distance to the power q. # The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value? 
ot_cost = ot.emd2(a, b, M, numItermax=2000000) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 121ba065..3a004d77 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -10,6 +10,7 @@ """ from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts +from gudhi.wasserstein.wasserstein import _warn_infty from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -50,16 +51,17 @@ def test_handle_essential_parts(): [-np.inf, np.inf], [-np.inf, np.inf]]) diag3 = np.array([[0, 2], [3, 5], - [2, np.inf], [4, np.inf], + [2, np.inf], [4, np.inf], [6, np.inf], [-np.inf, 8], [-np.inf, 11], - [-np.inf, -np.inf], [-np.inf, -np.inf], + [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) c, m = _handle_essential_parts(diag1, diag2, order=1) assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3) # Similarly, the matching only corresponds to essential parts. - assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]]) + # Note that (-inf,-inf) and (+inf,+inf) coordinates are matched to the diagonal. 
+ assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, -1], [7, -1], [-1, 6], [-1, 7]]) c, m = _handle_essential_parts(diag1, diag3, order=1) assert c == np.inf @@ -87,6 +89,13 @@ def test_get_essential_parts(): assert np.array_equal(res2[4], [] ) +def test_warn_infty(): + assert _warn_infty(matching=False)==np.inf + c, m = _warn_infty(matching=True) + assert (c == np.inf) + assert (m is None) + + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -143,11 +152,29 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat if test_matching and test_infinity: diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]]) + diag8 = np.array([[0,1], [0, np.inf], [-np.inf, -np.inf], [np.inf, np.inf]]) + diag9 = np.array([[-np.inf, -np.inf], [np.inf, np.inf]]) + diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]]) match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1] assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1] assert (match is None) + cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(emptydiag, diag7, matching=True, internal_p=2.42, order=2.) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(diag8, diag9, matching=True, internal_p=2., order=2.) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.) + assert (cost == 1) + assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway. 
+ cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.) + assert (cost == 0.) + assert (match == [[0, -1], [1, -1]]) def hera_wrap(**extra): @@ -155,15 +182,18 @@ def hera_wrap(**extra): return hera(*kargs,**kwargs,**extra) return fun + def pot_wrap(**extra): def fun(*kargs,**kwargs): return pot(*kargs,**kwargs,**extra) return fun + def test_wasserstein_distance_pot(): _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False) + def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) |