summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-10 16:47:09 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-10 16:47:09 +0100
commit967ceab26b09ad74e0cff0d84429a766af267f6b (patch)
tree51a06301124c9cc8f536043b8a17699c5671009e /src/python/gudhi/wasserstein.py
parent2eca5c75b1fbd7157e2656b875e730dc5f00f373 (diff)
removed _clean_match and changed matching format, it is now a (n x 2) numpy array
Diffstat (limited to 'src/python/gudhi/wasserstein.py')
-rw-r--r--src/python/gudhi/wasserstein.py31
1 files changed, 6 insertions, 25 deletions
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
index e28c63e6..9efa946e 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein.py
@@ -64,34 +64,13 @@ def _perstot(X, order, internal_p):
return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order)
-def _clean_match(match, n, m):
- '''
- :param match: a list of the form [(i,j) ...]
- :param n: int, size of the first dgm
- :param m: int, size of the second dgm
- :return: a modified version of match where indices greater than n, m are replaced by -1, encoding the diagonal.
- and (-1, -1) are removed
- '''
- new_match = []
- for i,j in match:
- if i >= n:
- if j < m:
- new_match.append((-1, j))
- elif j >= m:
- if i < n:
- new_match.append((i,-1))
- else:
- new_match.append((i,j))
- return new_match
-
-
def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
'''
:param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
(i.e. with infinite coordinate).
:param Y: (m x 2) numpy.array encoding the second diagram.
:param matching: if True, computes and returns the optimal matching between X and Y, encoded as
- a list of tuple [...(i,j)...], meaning the i-th point in X is matched to
+ a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
the j-th point in Y, with the convention (-1) represents the diagonal.
:param order: exponent for Wasserstein; Default value is 2.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
@@ -114,12 +93,12 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
if not matching:
return _perstot(Y, order, internal_p)
else:
- return _perstot(Y, order, internal_p), [(-1, j) for j in range(m)]
+ return _perstot(Y, order, internal_p), np.array([[-1, j] for j in range(m)])
elif Y.size == 0:
if not matching:
return _perstot(X, order, internal_p)
else:
- return _perstot(X, order, internal_p), [(i, -1) for i in range(n)]
+ return _perstot(X, order, internal_p), np.array([[i, -1] for i in range(n)])
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
@@ -130,9 +109,11 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
if matching:
P = ot.emd(a=a,b=b,M=M, numItermax=2000000)
ot_cost = np.sum(np.multiply(P,M))
+ P[-1, -1] = 0 # Remove matching corresponding to the diagonal
match = np.argwhere(P)
# Now we turn to -1 points encoding the diagonal
- match = _clean_match(match, n, m)
+ match[:,0][match[:,0] >= n] = -1
+ match[:,1][match[:,1] >= m] = -1
return ot_cost ** (1./order) , match
# Comptuation of the otcost using the ot.emd2 library.