diff options
Diffstat (limited to 'src/python/test')
-rwxr-xr-x | src/python/test/test_simplex_tree.py | 82 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_barycenter.py | 46 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 2 |
3 files changed, 129 insertions, 1 deletions
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index f7848379..70b26e97 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -9,6 +9,7 @@ """ from gudhi import SimplexTree +import pytest __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2016 Inria" @@ -250,6 +251,87 @@ def test_make_filtration_non_decreasing(): assert st.filtration([3, 4]) == 2.0 assert st.filtration([4, 5]) == 2.0 +def test_extend_filtration(): + + # Inserted simplex: + # 5 4 + # o o + # / \ / + # o o + # /2\ /3 + # o o + # 1 0 + + st = SimplexTree() + st.insert([0,2]) + st.insert([1,2]) + st.insert([0,3]) + st.insert([2,5]) + st.insert([3,4]) + st.insert([3,5]) + st.assign_filtration([0], 1.) + st.assign_filtration([1], 2.) + st.assign_filtration([2], 3.) + st.assign_filtration([3], 4.) + st.assign_filtration([4], 5.) + st.assign_filtration([5], 6.) + + assert list(st.get_filtration()) == [ + ([0, 2], 0.0), + ([1, 2], 0.0), + ([0, 3], 0.0), + ([3, 4], 0.0), + ([2, 5], 0.0), + ([3, 5], 0.0), + ([0], 1.0), + ([1], 2.0), + ([2], 3.0), + ([3], 4.0), + ([4], 5.0), + ([5], 6.0) + ] + + st.extend_filtration() + + assert list(st.get_filtration()) == [ + ([6], -3.0), + ([0], -2.0), + ([1], -1.8), + ([2], -1.6), + ([0, 2], -1.6), + ([1, 2], -1.6), + ([3], -1.4), + ([0, 3], -1.4), + ([4], -1.2), + ([3, 4], -1.2), + ([5], -1.0), + ([2, 5], -1.0), + ([3, 5], -1.0), + ([5, 6], 1.0), + ([4, 6], 1.2), + ([3, 6], 1.4), + ([3, 4, 6], 1.4), + ([3, 5, 6], 1.4), + ([2, 6], 1.6), + ([2, 5, 6], 1.6), + ([1, 6], 1.8), + ([1, 2, 6], 1.8), + ([0, 6], 2.0), + ([0, 2, 6], 2.0), + ([0, 3, 6], 2.0) + ] + + dgms = st.extended_persistence(min_persistence=-1.) + + assert dgms[0][0][1][0] == pytest.approx(2.) + assert dgms[0][0][1][1] == pytest.approx(3.) + assert dgms[1][0][1][0] == pytest.approx(5.) + assert dgms[1][0][1][1] == pytest.approx(4.) + assert dgms[2][0][1][0] == pytest.approx(1.) + assert dgms[2][0][1][1] == pytest.approx(6.) + assert dgms[3][0][1][0] == pytest.approx(6.) + assert dgms[3][0][1][1] == pytest.approx(1.) + def test_simplices_iterator(): st = SimplexTree() diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py new file mode 100755 index 00000000..f68c748e --- /dev/null +++ b/src/python/test/test_wasserstein_barycenter.py @@ -0,0 +1,46 @@ +from gudhi.wasserstein.barycenter import lagrangian_barycenter +import numpy as np + +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Theo Lacombe + + Copyright (C) 2019 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +__author__ = "Theo Lacombe" +__copyright__ = "Copyright (C) 2019 Inria" +__license__ = "MIT" + + +def test_lagrangian_barycenter(): + + dg1 = np.array([[0.2, 0.5]]) + dg2 = np.array([[0.2, 0.7]]) + dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]]) + dg4 = np.array([]) + dg5 = np.array([]) + dg6 = np.array([]) + res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]]) + + dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]]) + dg8 = np.array([[0., 4.], [4, 8]]) + + # error crit. + eps = 1e-7 + + + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps + assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2))) + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps + Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True) + assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps + assert np.abs(log["energy"] - 2) < eps + assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]])) + assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]])) + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps + assert lagrangian_barycenter(pdiagset = []) is None + diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 0d70e11a..7e0d0f5f 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -70,7 +70,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat assert np.array_equal(match , [[0, -1], [1, -1]]) match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) - + def hera_wrap(delta): |