summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/doc/wasserstein_distance_user.rst24
-rw-r--r--src/python/gudhi/wasserstein.py69
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py31
3 files changed, 102 insertions, 22 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index 94b454e2..d3daa318 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -47,3 +47,27 @@ The output is:
.. testoutput::
Wasserstein distance value = 1.45
+
+We can also have access to the optimal matching by letting `matching=True`.
+It is encoded as a list of indices (i,j), meaning that the i-th point in X
+is mapped to the j-th point in Y.
+An index of -1 represents the diagonal.
+
+.. testcode::
+
+ import gudhi.wasserstein
+ import numpy as np
+
+ diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
+ diag2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]])
+ cost, matching = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.)
+
+ message = "Wasserstein distance value = %.2f, optimal matching: %s" %(cost, matching)
+ print(message)
+
+The output is:
+
+.. testoutput::
+
+ Wasserstein distance value = 2.15, optimal matching: [(0, 0), (1, 2), (2, -1), (-1, 1)]
+
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
index 13102094..ba0f7343 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein.py
@@ -62,14 +62,39 @@ def _perstot(X, order, internal_p):
return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order)
-def wasserstein_distance(X, Y, order=2., internal_p=2.):
+def _clean_match(match, n, m):
'''
- :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate).
+ :param match: a list of the form [(i,j) ...]
+ :param n: int, size of the first dgm
+ :param m: int, size of the second dgm
+ :return: a modified version of match where indices greater than n, m are replaced by -1, encoding the diagonal.
+ and (-1, -1) are removed
+ '''
+ new_match = []
+ for i,j in match:
+ if i >= n:
+ if j < m:
+ new_match.append((-1, j))
+ elif j >= m:
+ if i < n:
+ new_match.append((i,-1))
+ else:
+ new_match.append((i,j))
+ return new_match
+
+
+def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
+ '''
+ :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
+ (i.e. with infinite coordinate).
:param Y: (m x 2) numpy.array encoding the second diagram.
+ :param matching: if True, computes and returns the optimal matching between X and Y, encoded as...
:param order: exponent for Wasserstein; Default value is 2.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
- :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric.
- :rtype: float
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
+ Default value is 2 (Euclidean norm).
+ :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
+ respect to the internal_p-norm as ground metric.
+ If matching is set to True, also returns the optimal matching between X and Y.
'''
n = len(X)
m = len(Y)
@@ -77,21 +102,39 @@ def wasserstein_distance(X, Y, order=2., internal_p=2.):
# handle empty diagrams
if X.size == 0:
if Y.size == 0:
- return 0.
+ if not matching:
+ return 0.
+ else:
+ return 0., []
else:
- return _perstot(Y, order, internal_p)
+ if not matching:
+ return _perstot(Y, order, internal_p)
+ else:
+ return _perstot(Y, order, internal_p), [(-1, j) for j in range(m)]
elif Y.size == 0:
- return _perstot(X, order, internal_p)
+ if not matching:
+ return _perstot(X, order, internal_p)
+ else:
+ return _perstot(X, order, internal_p), [(i, -1) for i in range(n)]
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
- a = np.full(n+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
- a[-1] = a[-1] * m # normalized so that we have a probability measure, required by POT
- b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
- b[-1] = b[-1] * n # so that we have a probability measure, required by POT
+ a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
+ a[-1] = m
+ b = np.ones(m+1) # weight vector of the input diagram. Uniform here.
+ b[-1] = n
+
+ if matching:
+ P = ot.emd(a=a,b=b,M=M, numItermax=2000000)
+ ot_cost = np.sum(np.multiply(P,M))
+ P[P < 0.5] = 0 # trick to avoid numerical issue, could it be improved?
+ match = np.argwhere(P)
+ # Now we turn to -1 points encoding the diagonal
+ match = _clean_match(match, n, m)
+ return ot_cost ** (1./order) , match
# Comptuation of the otcost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
- ot_cost = (n+m) * ot.emd2(a, b, M, numItermax=2000000)
+ ot_cost = ot.emd2(a, b, M, numItermax=2000000)
return ot_cost ** (1./order)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 6a6b217b..02a1d2c9 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -51,14 +51,27 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5))
assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5))
- if(not test_infinity):
- return
+ if test_infinity:
+ diag5 = np.array([[0, 3], [4, np.inf]])
+ diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
- diag5 = np.array([[0, 3], [4, np.inf]])
- diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
+ assert wasserstein_distance(diag4, diag5) == np.inf
+ assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
+
+
+ if test_matching:
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
+ assert match == []
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert match == []
+ match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
+ assert match == [(-1, 0), (-1, 1)]
+ match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert match == [(0, -1), (1, -1)]
+ match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
+ assert match == [(0, 0), (1, 1), (2, -1)]
+
- assert wasserstein_distance(diag4, diag5) == np.inf
- assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
def hera_wrap(delta):
def fun(*kargs,**kwargs):
@@ -66,8 +79,8 @@ def hera_wrap(delta):
return fun
def test_wasserstein_distance_pot():
- _basic_wasserstein(pot, 1e-15, test_infinity=False)
+ _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(1e-12), 1e-12)
- _basic_wasserstein(hera_wrap(.1), .1)
+ _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(.1), .1, test_matching=False)