diff options
author | MathieuCarriere <mathieu.carriere3@gmail.com> | 2020-04-28 13:48:45 -0400 |
---|---|---|
committer | MathieuCarriere <mathieu.carriere3@gmail.com> | 2020-04-28 13:48:45 -0400 |
commit | 4923f2bd8a18d2f66288f39c08309cb7cafa5627 (patch) | |
tree | 0f9572654e52fc0b0bc7994f07aee1a874c2a45a /src/python/test/test_wasserstein_distance.py | |
parent | 39b6731486838b8f2e608e5b5738c12e1c83266f (diff) | |
parent | 0fb22e4c499b665ad505e5d9d2c325f7561f69c4 (diff) |
fix conflict
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 0d70e11a..1a4acc1d 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -8,6 +8,7 @@ - YYYY/MM Author: Description of the modification """ +from gudhi.wasserstein.wasserstein import _proj_on_diag from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -17,6 +18,12 @@ __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" +def test_proj_on_diag(): + dgm = np.array([[1., 1.], [1., 2.], [3., 5.]]) + assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]]) + empty = np.empty((0, 2)) + assert np.array_equal(_proj_on_diag(empty), empty) + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -70,7 +77,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat assert np.array_equal(match , [[0, -1], [1, -1]]) match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) - + def hera_wrap(delta): |