summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/python')
-rw-r--r--src/python/doc/wasserstein_distance_user.rst24
-rw-r--r--src/python/gudhi/wasserstein.py57
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py33
3 files changed, 89 insertions, 25 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index 94b454e2..9519caa6 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -47,3 +47,27 @@ The output is:
.. testoutput::
Wasserstein distance value = 1.45
+
+We can also have access to the optimal matching by letting `matching=True`.
+It is encoded as a list of indices (i,j), meaning that the i-th point in X
+is mapped to the j-th point in Y.
+An index of -1 represents the diagonal.
+
+.. testcode::
+
+ import gudhi.wasserstein
+ import numpy as np
+
+ diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
+ diag2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]])
+ cost, matching = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.)
+
+ message = "Wasserstein distance value = %.2f, optimal matching: %s" %(cost, matching)
+ print(message)
+
+The output is:
+
+.. testoutput::
+
+ Wasserstein distance value = 2.15, optimal matching: [[0, 0], [1, 2], [2, -1], [-1, 1]]
+
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
index 13102094..3dd993f9 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein.py
@@ -30,8 +30,10 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.):
:param order: exponent for the Wasserstein metric.
:param internal_p: Ground metric (i.e. norm L^p).
:returns: (n+1) x (m+1) np.array encoding the cost matrix C.
- For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal.
- note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal).
+ For 0 <= i < n, 0 <= j < m, C[i,j] encodes the distance between X[i] and Y[j],
+ while C[i, m] (resp. C[n, j]) encodes the distance (to the p) between X[i] (resp Y[j])
+ and its orthogonal projection onto the diagonal.
+ note also that C[n, m] = 0 (it costs nothing to move from the diagonal to the diagonal).
'''
Xdiag = _proj_on_diag(X)
Ydiag = _proj_on_diag(Y)
@@ -62,14 +64,20 @@ def _perstot(X, order, internal_p):
return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order)
-def wasserstein_distance(X, Y, order=2., internal_p=2.):
+def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
'''
- :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate).
+ :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
+ (i.e. with infinite coordinate).
:param Y: (m x 2) numpy.array encoding the second diagram.
+ :param matching: if True, computes and returns the optimal matching between X and Y, encoded as
+ a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
+ the j-th point in Y, with the convention (-1) represents the diagonal.
:param order: exponent for Wasserstein; Default value is 2.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
- :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric.
- :rtype: float
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
+ Default value is 2 (Euclidean norm).
+ :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
+ respect to the internal_p-norm as ground metric.
+ If matching is set to True, also returns the optimal matching between X and Y.
'''
n = len(X)
m = len(Y)
@@ -77,21 +85,40 @@ def wasserstein_distance(X, Y, order=2., internal_p=2.):
# handle empty diagrams
if X.size == 0:
if Y.size == 0:
- return 0.
+ if not matching:
+ return 0.
+ else:
+ return 0., np.array([])
else:
- return _perstot(Y, order, internal_p)
+ if not matching:
+ return _perstot(Y, order, internal_p)
+ else:
+ return _perstot(Y, order, internal_p), np.array([[-1, j] for j in range(m)])
elif Y.size == 0:
- return _perstot(X, order, internal_p)
+ if not matching:
+ return _perstot(X, order, internal_p)
+ else:
+ return _perstot(X, order, internal_p), np.array([[i, -1] for i in range(n)])
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
- a = np.full(n+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
- a[-1] = a[-1] * m # normalized so that we have a probability measure, required by POT
- b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
- b[-1] = b[-1] * n # so that we have a probability measure, required by POT
+ a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
+ a[-1] = m
+ b = np.ones(m+1) # weight vector of the input diagram. Uniform here.
+ b[-1] = n
+
+ if matching:
+ P = ot.emd(a=a,b=b,M=M, numItermax=2000000)
+ ot_cost = np.sum(np.multiply(P,M))
+ P[-1, -1] = 0 # Remove matching corresponding to the diagonal
+ match = np.argwhere(P)
+ # Now we turn to -1 points encoding the diagonal
+ match[:,0][match[:,0] >= n] = -1
+ match[:,1][match[:,1] >= m] = -1
+ return ot_cost ** (1./order) , match
# Comptuation of the otcost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
- ot_cost = (n+m) * ot.emd2(a, b, M, numItermax=2000000)
+ ot_cost = ot.emd2(a, b, M, numItermax=2000000)
return ot_cost ** (1./order)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 6a6b217b..0d70e11a 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -17,7 +17,7 @@ __author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
__license__ = "MIT"
-def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
+def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
diag3 = np.array([[0, 2], [4, 6]])
@@ -51,14 +51,27 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5))
assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5))
- if(not test_infinity):
- return
+ if test_infinity:
+ diag5 = np.array([[0, 3], [4, np.inf]])
+ diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
- diag5 = np.array([[0, 3], [4, np.inf]])
- diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
+ assert wasserstein_distance(diag4, diag5) == np.inf
+ assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
+
+
+ if test_matching:
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
+ assert np.array_equal(match, [])
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert np.array_equal(match, [])
+ match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
+ assert np.array_equal(match , [[-1, 0], [-1, 1]])
+ match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert np.array_equal(match , [[0, -1], [1, -1]])
+ match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
+ assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
+
- assert wasserstein_distance(diag4, diag5) == np.inf
- assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
def hera_wrap(delta):
def fun(*kargs,**kwargs):
@@ -66,8 +79,8 @@ def hera_wrap(delta):
return fun
def test_wasserstein_distance_pot():
- _basic_wasserstein(pot, 1e-15, test_infinity=False)
+ _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(1e-12), 1e-12)
- _basic_wasserstein(hera_wrap(.1), .1)
+ _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(.1), .1, test_matching=False)