summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-20 18:41:59 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-20 18:41:59 +0200
commit0393fdd3da2b5e403757c0f3418919c81ccbdd76 (patch)
tree008b752c6069b165efee50cb928adc8267343101 /src/python/test/test_wasserstein_distance.py
parent1086b8cad7c1ea2a02742dfc44aef036a674f5d3 (diff)
parent93cd1240ef65d8883ec624e6e393c09969bf4d6f (diff)
Merge remote-tracking branch 'origin/master' into wass-autodiff
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index c6d6b346..6bfcb2ee 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -8,6 +8,7 @@
- YYYY/MM Author: Description of the modification
"""
+from gudhi.wasserstein.wasserstein import _proj_on_diag
from gudhi.wasserstein import wasserstein_distance as pot
from gudhi.hera import wasserstein_distance as hera
import numpy as np
@@ -17,6 +18,12 @@ __author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
__license__ = "MIT"
+def test_proj_on_diag():
+ dgm = np.array([[1., 1.], [1., 2.], [3., 5.]])
+ assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]])
+ empty = np.empty((0, 2))
+ assert np.array_equal(_proj_on_diag(empty), empty)
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])