diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-23 14:51:53 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-23 14:51:53 +0200 |
commit | 031e6879f94503e5250c005f8cb71e581799d2f3 (patch) | |
tree | 500c55f4b47a33b933f03ffad16b8c26df121830 /src/python/test/test_wasserstein_distance.py | |
parent | 65f6ca41a9cd6574a0ca8fa9b781c787064fe4ed (diff) | |
parent | e3f276ab5b7503ba7ce278fffbf73ebe66d6351c (diff) |
Merge remote-tracking branch 'origin/master' into compute_persistence
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 0d70e11a..1a4acc1d 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -8,6 +8,7 @@ - YYYY/MM Author: Description of the modification """ +from gudhi.wasserstein.wasserstein import _proj_on_diag from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -17,6 +18,12 @@ __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" +def test_proj_on_diag(): + dgm = np.array([[1., 1.], [1., 2.], [3., 5.]]) + assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]]) + empty = np.empty((0, 2)) + assert np.array_equal(_proj_on_diag(empty), empty) + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -70,7 +77,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat assert np.array_equal(match , [[0, -1], [1, -1]]) match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) - + def hera_wrap(delta): |