From c1ab7c43d4797da93aa74ba823dd1a6b28fb2cfd Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 27 Apr 2021 12:16:22 +0200 Subject: now consider (inf,inf) as belonging to the diagonal ; more tests --- src/python/gudhi/wasserstein/wasserstein.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'src/python/gudhi/wasserstein') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 3abecfe6..5095e672 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -106,6 +106,8 @@ def _get_essential_parts(a): .. note:: For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x. Note also that points with (+inf, -inf) are not handled (points (x,y) in dgm satisfy by assumption (y >= x)). + + Finally, we consider that points with coordinates (-inf,-inf) and (+inf, +inf) belong to the diagonal. ''' if len(a): first_coord_finite = np.isfinite(a[:,0]) @@ -118,6 +120,7 @@ def _get_essential_parts(a): ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x) ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf) ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf) + ess_fourth_type = np.where(first_coord_infinite_negative & second_coord_infinite_negative)[0] # coord (-inf, -inf) ess_fifth_type = np.where(first_coord_infinite_positive & second_coord_infinite_positive)[0] # coord (+inf, +inf) return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type @@ -162,7 +165,7 @@ def _handle_essential_parts(X, Y, order): ess_parts_Y = _get_essential_parts(Y) # Treats the case of infinite cost (cardinalities of essential parts differ). - for u, v in zip(ess_parts_X, ess_parts_Y): + for u, v in list(zip(ess_parts_X, ess_parts_Y))[:3]: # ignore types 4 and 5 as they belong to the diagonal if len(u) != len(v): return np.inf, None @@ -174,9 +177,14 @@ def _handle_essential_parts(X, Y, order): c = c1 + c2 m = m1 + m2 - # Handle type >= 2 (both coordinates are infinite, so we essentially just align points) - for u, v in zip(ess_parts_X[2:], ess_parts_Y[2:]): - m += list(zip(u, v)) # cost is 0 + # Handle type3 (coordinates (-inf,+inf), so we just align points) + m += list(zip(ess_parts_X[2], ess_parts_Y[2])) + + # Handle type 4 and 5, considered as belonging to the diagonal so matched to (-1) with cost 0. + for z in ess_parts_X[3:]: + m += [(u, -1) for u in z] # points in X are matched to -1 + for z in ess_parts_Y[3:]: + m += [(-1, v) for v in z] # -1 is match to points in Y return c, np.array(m) @@ -334,7 +342,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab return ep.concatenate(dists).norms.lp(order).raw # We can also concatenate the 3 vectors to compute just one norm. - # Comptuation of the otcost using the ot.emd2 library. + # Comptuation of the ot cost using the ot.emd2 library. # Note: it is the Wasserstein distance to the power q. # The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value? ot_cost = ot.emd2(a, b, M, numItermax=2000000) -- cgit v1.2.3