summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-20 17:54:13 +0200
committerGitHub <noreply@github.com>2020-04-20 17:54:13 +0200
commit6a397d32ad4e771aab7d8e2da88e4b857258d126 (patch)
treef63c32f4bb1080db0788f3ffb550f21f954dd40c /src
parent9cd2062958d20a86fe234b960524160d31e18396 (diff)
parent3a9105e0d3bea5cc64610b7c0c3fb15f0e00bb9d (diff)
Merge pull request #284 from mglisse/wdiag
Simplify distance-to-diagonal in Wasserstein
Diffstat (limited to 'src')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py27
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py7
2 files changed, 25 insertions, 9 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 35315939..efc851a0 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -15,6 +15,8 @@ try:
except ImportError:
print("POT (Python Optimal Transport) package is not installed. Try to run $ conda install -c conda-forge pot ; or $ pip install POT")
+
+# Currently unused, but Théo says it is likely to be used again.
def _proj_on_diag(X):
'''
:param X: (n x 2) array encoding the points of a persistent diagram.
@@ -24,7 +26,19 @@ def _proj_on_diag(X):
return np.array([Z , Z]).T
-def _build_dist_matrix(X, Y, order=2., internal_p=2.):
+def _dist_to_diag(X, internal_p):
+ '''
+ :param X: (n x 2) array encoding the points of a persistent diagram.
+ :param internal_p: Ground metric (i.e. norm L^p).
+ :returns: (n) array encoding the (respective orthogonal) distances of the points to the diagonal
+
+ .. note::
+ Assumes that the points are above the diagonal.
+ '''
+ return (X[:, 1] - X[:, 0]) * 2 ** (1.0 / internal_p - 1)
+
+
+def _build_dist_matrix(X, Y, order, internal_p):
'''
:param X: (n x 2) numpy.array encoding the (points of the) first diagram.
:param Y: (m x 2) numpy.array encoding the second diagram.
@@ -36,16 +50,12 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.):
and its orthogonal projection onto the diagonal.
note also that C[n, m] = 0 (it costs nothing to move from the diagonal to the diagonal).
'''
- Xdiag = _proj_on_diag(X)
- Ydiag = _proj_on_diag(Y)
+ Cxd = _dist_to_diag(X, internal_p)**order
+ Cdy = _dist_to_diag(Y, internal_p)**order
if np.isinf(internal_p):
C = sc.cdist(X,Y, metric='chebyshev')**order
- Cxd = np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order
- Cdy = np.linalg.norm(Y - Ydiag, ord=internal_p, axis=1)**order
else:
C = sc.cdist(X,Y, metric='minkowski', p=internal_p)**order
- Cxd = np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order
- Cdy = np.linalg.norm(Y - Ydiag, ord=internal_p, axis=1)**order
Cf = np.hstack((C, Cxd[:,None]))
Cdy = np.append(Cdy, 0)
@@ -61,8 +71,7 @@ def _perstot(X, order, internal_p):
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
:returns: float, the total persistence of the diagram (that is, its distance to the empty diagram).
'''
- Xdiag = _proj_on_diag(X)
- return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order)
+ return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 7e0d0f5f..1a4acc1d 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -8,6 +8,7 @@
- YYYY/MM Author: Description of the modification
"""
+from gudhi.wasserstein.wasserstein import _proj_on_diag
from gudhi.wasserstein import wasserstein_distance as pot
from gudhi.hera import wasserstein_distance as hera
import numpy as np
@@ -17,6 +18,12 @@ __author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
__license__ = "MIT"
+def test_proj_on_diag():
+ dgm = np.array([[1., 1.], [1., 2.], [3., 5.]])
+ assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]])
+ empty = np.empty((0, 2))
+ assert np.array_equal(_proj_on_diag(empty), empty)
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])