summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-03 15:33:17 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-03 15:33:17 +0100
commit8e4f3d151818b78a29d11cdc6ca171947bfd6dd9 (patch)
treef7f2d562332cfa22b90628e0e95dd22739322f9e /src/python/gudhi/wasserstein.py
parentd2943b9e7311c8a3d8a4fb379c39b15497481b9c (diff)
update wasserstein distance with pot so that it can return optimal matching now!
Diffstat (limited to 'src/python/gudhi/wasserstein.py')
-rw-r--r--src/python/gudhi/wasserstein.py69
1 files changed, 56 insertions, 13 deletions
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
index 13102094..ba0f7343 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein.py
@@ -62,14 +62,39 @@ def _perstot(X, order, internal_p):
return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order)
-def wasserstein_distance(X, Y, order=2., internal_p=2.):
+def _clean_match(match, n, m):
'''
- :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate).
+ :param match: a list of the form [(i,j) ...]
+ :param n: int, size of the first dgm
+ :param m: int, size of the second dgm
+ :return: a modified version of match where indices greater than n, m are replaced by -1, encoding the diagonal.
+ and (-1, -1) are removed
+ '''
+ new_match = []
+ for i,j in match:
+ if i >= n:
+ if j < m:
+ new_match.append((-1, j))
+ elif j >= m:
+ if i < n:
+ new_match.append((i,-1))
+ else:
+ new_match.append((i,j))
+ return new_match
+
+
+def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
+ '''
+ :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
+ (i.e. with infinite coordinate).
:param Y: (m x 2) numpy.array encoding the second diagram.
+ :param matching: if True, computes and returns the optimal matching between X and Y, encoded as...
:param order: exponent for Wasserstein; Default value is 2.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
- :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric.
- :rtype: float
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
+ Default value is 2 (Euclidean norm).
+ :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
+ respect to the internal_p-norm as ground metric.
+ If matching is set to True, also returns the optimal matching between X and Y.
'''
n = len(X)
m = len(Y)
@@ -77,21 +102,39 @@ def wasserstein_distance(X, Y, order=2., internal_p=2.):
# handle empty diagrams
if X.size == 0:
if Y.size == 0:
- return 0.
+ if not matching:
+ return 0.
+ else:
+ return 0., []
else:
- return _perstot(Y, order, internal_p)
+ if not matching:
+ return _perstot(Y, order, internal_p)
+ else:
+ return _perstot(Y, order, internal_p), [(-1, j) for j in range(m)]
elif Y.size == 0:
- return _perstot(X, order, internal_p)
+ if not matching:
+ return _perstot(X, order, internal_p)
+ else:
+ return _perstot(X, order, internal_p), [(i, -1) for i in range(n)]
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
- a = np.full(n+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
- a[-1] = a[-1] * m # normalized so that we have a probability measure, required by POT
- b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
- b[-1] = b[-1] * n # so that we have a probability measure, required by POT
+ a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
+ a[-1] = m
+ b = np.ones(m+1) # weight vector of the input diagram. Uniform here.
+ b[-1] = n
+
+ if matching:
+ P = ot.emd(a=a,b=b,M=M, numItermax=2000000)
+ ot_cost = np.sum(np.multiply(P,M))
+ P[P < 0.5] = 0 # trick to avoid numerical issue, could it be improved?
+ match = np.argwhere(P)
+ # Now we turn to -1 points encoding the diagonal
+ match = _clean_match(match, n, m)
+ return ot_cost ** (1./order) , match
# Comptuation of the otcost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
- ot_cost = (n+m) * ot.emd2(a, b, M, numItermax=2000000)
+ ot_cost = ot.emd2(a, b, M, numItermax=2000000)
return ot_cost ** (1./order)