diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-20 11:37:44 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-20 11:37:44 +0200 |
commit | 3a9105e0d3bea5cc64610b7c0c3fb15f0e00bb9d (patch) | |
tree | f63c32f4bb1080db0788f3ffb550f21f954dd40c /src/python | |
parent | 17aaa979e4cdfe5faed9b2750d452171de4b67e1 (diff) |
Reintroduce _proj_on_diag, with a unit test
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 11 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 7 |
2 files changed, 18 insertions, 0 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 5df66cf9..efc851a0 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -15,6 +15,17 @@ try: except ImportError: print("POT (Python Optimal Transport) package is not installed. Try to run $ conda install -c conda-forge pot ; or $ pip install POT") + +# Currently unused, but Théo says it is likely to be used again. +def _proj_on_diag(X): + ''' + :param X: (n x 2) array encoding the points of a persistent diagram. + :returns: (n x 2) array encoding the (respective orthogonal) projections of the points onto the diagonal + ''' + Z = (X[:,0] + X[:,1]) / 2. + return np.array([Z , Z]).T + + def _dist_to_diag(X, internal_p): ''' :param X: (n x 2) array encoding the points of a persistent diagram. diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 7e0d0f5f..1a4acc1d 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -8,6 +8,7 @@ - YYYY/MM Author: Description of the modification """ +from gudhi.wasserstein.wasserstein import _proj_on_diag from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -17,6 +18,12 @@ __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" +def test_proj_on_diag(): + dgm = np.array([[1., 1.], [1., 2.], [3., 5.]]) + assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]]) + empty = np.empty((0, 2)) + assert np.array_equal(_proj_on_diag(empty), empty) + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) |