summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-14 19:49:49 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-14 19:49:49 +0200
commit3f1e4bf5f139afe887ae501f20c5d3f40b5a6f79 (patch)
tree936aa887b1e7a0e23854262f17a5d21cf2f604c1 /src/python/test
parent9518287cfa2a62948ede2e7d17d5c9f29092e0f4 (diff)
parent6d02ca0e077cc9750275abdfc024429cec0ba5a5 (diff)
Merge remote-tracking branch 'origin/master' into dtm
Diffstat (limited to 'src/python/test')
-rwxr-xr-xsrc/python/test/test_simplex_tree.py82
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py46
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py2
3 files changed, 129 insertions, 1 deletions
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index f7848379..70b26e97 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -9,6 +9,7 @@
"""
from gudhi import SimplexTree
+import pytest
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
@@ -250,6 +251,87 @@ def test_make_filtration_non_decreasing():
assert st.filtration([3, 4]) == 2.0
assert st.filtration([4, 5]) == 2.0
+def test_extend_filtration():
+
+ # Inserted simplex:
+ # 5 4
+ # o o
+ # / \ /
+ # o o
+ # /2\ /3
+ # o o
+ # 1 0
+
+ st = SimplexTree()
+ st.insert([0,2])
+ st.insert([1,2])
+ st.insert([0,3])
+ st.insert([2,5])
+ st.insert([3,4])
+ st.insert([3,5])
+ st.assign_filtration([0], 1.)
+ st.assign_filtration([1], 2.)
+ st.assign_filtration([2], 3.)
+ st.assign_filtration([3], 4.)
+ st.assign_filtration([4], 5.)
+ st.assign_filtration([5], 6.)
+
+ assert list(st.get_filtration()) == [
+ ([0, 2], 0.0),
+ ([1, 2], 0.0),
+ ([0, 3], 0.0),
+ ([3, 4], 0.0),
+ ([2, 5], 0.0),
+ ([3, 5], 0.0),
+ ([0], 1.0),
+ ([1], 2.0),
+ ([2], 3.0),
+ ([3], 4.0),
+ ([4], 5.0),
+ ([5], 6.0)
+ ]
+
+ st.extend_filtration()
+
+ assert list(st.get_filtration()) == [
+ ([6], -3.0),
+ ([0], -2.0),
+ ([1], -1.8),
+ ([2], -1.6),
+ ([0, 2], -1.6),
+ ([1, 2], -1.6),
+ ([3], -1.4),
+ ([0, 3], -1.4),
+ ([4], -1.2),
+ ([3, 4], -1.2),
+ ([5], -1.0),
+ ([2, 5], -1.0),
+ ([3, 5], -1.0),
+ ([5, 6], 1.0),
+ ([4, 6], 1.2),
+ ([3, 6], 1.4),
+ ([3, 4, 6], 1.4),
+ ([3, 5, 6], 1.4),
+ ([2, 6], 1.6),
+ ([2, 5, 6], 1.6),
+ ([1, 6], 1.8),
+ ([1, 2, 6], 1.8),
+ ([0, 6], 2.0),
+ ([0, 2, 6], 2.0),
+ ([0, 3, 6], 2.0)
+ ]
+
+ dgms = st.extended_persistence(min_persistence=-1.)
+
+ assert dgms[0][0][1][0] == pytest.approx(2.)
+ assert dgms[0][0][1][1] == pytest.approx(3.)
+ assert dgms[1][0][1][0] == pytest.approx(5.)
+ assert dgms[1][0][1][1] == pytest.approx(4.)
+ assert dgms[2][0][1][0] == pytest.approx(1.)
+ assert dgms[2][0][1][1] == pytest.approx(6.)
+ assert dgms[3][0][1][0] == pytest.approx(6.)
+ assert dgms[3][0][1][1] == pytest.approx(1.)
+
def test_simplices_iterator():
st = SimplexTree()
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
new file mode 100755
index 00000000..f68c748e
--- /dev/null
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -0,0 +1,46 @@
+from gudhi.wasserstein.barycenter import lagrangian_barycenter
+import numpy as np
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2019 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Theo Lacombe"
+__copyright__ = "Copyright (C) 2019 Inria"
+__license__ = "MIT"
+
+
+def test_lagrangian_barycenter():
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+ dg5 = np.array([])
+ dg6 = np.array([])
+ res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
+
+ dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
+ dg8 = np.array([[0., 4.], [4, 8]])
+
+ # error crit.
+ eps = 1e-7
+
+
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps
+ assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2)))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps
+ Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True)
+ assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps
+ assert np.abs(log["energy"] - 2) < eps
+ assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]]))
+ assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]]))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps
+ assert lagrangian_barycenter(pdiagset = []) is None
+
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 0d70e11a..7e0d0f5f 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -70,7 +70,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
assert np.array_equal(match , [[0, -1], [1, -1]])
match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
-
+
def hera_wrap(delta):