diff options
Diffstat (limited to 'src/python/test')
-rwxr-xr-x | src/python/test/test_alpha_complex.py | 27 | ||||
-rw-r--r-- | src/python/test/test_persistence_graphical_tools.py | 121 | ||||
-rwxr-xr-x | src/python/test/test_simplex_tree.py | 131 |
3 files changed, 267 insertions, 12 deletions
diff --git a/src/python/test/test_alpha_complex.py b/src/python/test/test_alpha_complex.py index f15284f3..f81e6137 100755 --- a/src/python/test/test_alpha_complex.py +++ b/src/python/test/test_alpha_complex.py @@ -286,3 +286,30 @@ def _weighted_doc_example(precision): def test_weighted_doc_example(): for precision in ['fast', 'safe', 'exact']: _weighted_doc_example(precision) + +def test_float_relative_precision(): + assert AlphaComplex.get_float_relative_precision() == 1e-5 + # Must be > 0. + with pytest.raises(ValueError): + AlphaComplex.set_float_relative_precision(0.) + # Must be < 1. + with pytest.raises(ValueError): + AlphaComplex.set_float_relative_precision(1.) + + points = [[1, 1], [7, 0], [4, 6], [9, 6], [0, 14], [2, 19], [9, 17]] + st = AlphaComplex(points=points).create_simplex_tree() + filtrations = list(st.get_filtration()) + + # Get a better precision + AlphaComplex.set_float_relative_precision(1e-15) + assert AlphaComplex.get_float_relative_precision() == 1e-15 + + st = AlphaComplex(points=points).create_simplex_tree() + filtrations_better_resolution = list(st.get_filtration()) + + assert len(filtrations) == len(filtrations_better_resolution) + for idx in range(len(filtrations)): + # check simplex is the same + assert filtrations[idx][0] == filtrations_better_resolution[idx][0] + # check filtration is about the same with a relative precision of the worst case + assert filtrations[idx][1] == pytest.approx(filtrations_better_resolution[idx][1], rel=1e-5) diff --git a/src/python/test/test_persistence_graphical_tools.py b/src/python/test/test_persistence_graphical_tools.py new file mode 100644 index 00000000..c19836b7 --- /dev/null +++ b/src/python/test/test_persistence_graphical_tools.py @@ -0,0 +1,121 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2021 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +import gudhi as gd +import numpy as np +import matplotlib as plt +import pytest + + +def test_array_handler(): + diags = np.array([[1, 2], [3, 4], [5, 6]], float) + arr_diags = gd.persistence_graphical_tools._array_handler(diags) + for idx in range(len(diags)): + assert arr_diags[idx][0] == 0 + np.testing.assert_array_equal(arr_diags[idx][1], diags[idx]) + + diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)] + arr_diags = gd.persistence_graphical_tools._array_handler(diags) + for idx in range(len(diags)): + assert arr_diags[idx][0] == 0 + assert arr_diags[idx][1] == diags[idx] + + diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))] + assert gd.persistence_graphical_tools._array_handler(diags) == diags + + +def test_min_birth_max_death(): + diags = [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + (0, (0.0, 0.118398)), + (0, (0.118398, 1.0)), + (0, (0.0, 0.117908)), + (0, (0.0, 0.112307)), + (0, (0.0, 0.107535)), + (0, (0.0, 0.106382)), + ] + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (0.0, 1.0) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (0.0, 5.0) + + +def test_limit_min_birth_max_death(): + diags = [ + (0, (2.0, float("inf"))), + (0, (2.0, float("inf"))), + ] + assert gd.persistence_graphical_tools.__min_birth_max_death(diags) == (2.0, 3.0) + assert gd.persistence_graphical_tools.__min_birth_max_death(diags, band=4.0) == (2.0, 6.0) + + +def test_limit_to_max_intervals(): + diags = [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + (0, (0.0, 0.118398)), + (0, (0.118398, 1.0)), + (0, (0.0, 0.117908)), + (0, (0.0, 0.112307)), + (0, (0.0, 0.107535)), + (0, (0.0, 0.106382)), + ] + # check no warnings if max_intervals equals to the diagrams number + with pytest.warns(None) as record: + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals( + diags, 10, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + # check diagrams are not sorted + assert truncated_diags == diags + assert len(record) == 0 + + # check warning if max_intervals lower than the diagrams number + with pytest.warns(UserWarning) as record: + truncated_diags = gd.persistence_graphical_tools._limit_to_max_intervals( + diags, 5, key=lambda life_time: life_time[1][1] - life_time[1][0] + ) + # check diagrams are truncated and sorted by life time + assert truncated_diags == [ + (0, (0.0, float("inf"))), + (0, (0.0983494, float("inf"))), + (0, (0.118398, 1.0)), + (0, (0.0, 0.122545)), + (0, (0.0, 0.12047)), + ] + assert len(record) == 1 + + +def _limit_plot_persistence(function): + pplot = function(persistence=[]) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[], legend=True) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))]) + assert isinstance(pplot, plt.axes.SubplotBase) + pplot = function(persistence=[(0, float("inf"))], legend=True) + assert isinstance(pplot, plt.axes.SubplotBase) + + +def test_limit_plot_persistence(): + for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: + _limit_plot_persistence(function) + + +def _non_existing_persistence_file(function): + with pytest.raises(FileNotFoundError): + function(persistence_file="pouetpouettralala.toubiloubabdou") + + +def test_non_existing_persistence_file(): + for function in [gd.plot_persistence_barcode, gd.plot_persistence_diagram, gd.plot_persistence_density]: + _non_existing_persistence_file(function) diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 31c46213..688f4fd6 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -8,7 +8,7 @@ - YYYY/MM Author: Description of the modification """ -from gudhi import SimplexTree, __GUDHI_USE_EIGEN3 +from gudhi import SimplexTree import numpy as np import pytest @@ -354,16 +354,11 @@ def test_collapse_edges(): assert st.num_simplices() == 10 - if __GUDHI_USE_EIGEN3: - st.collapse_edges() - assert st.num_simplices() == 9 - assert st.find([1, 3]) == False - for simplex in st.get_skeleton(0): - assert simplex[1] == 1. - else: - # If no Eigen3, collapse_edges throws an exception - with pytest.raises(RuntimeError): - st.collapse_edges() + st.collapse_edges() + assert st.num_simplices() == 9 + assert st.find([0, 2]) == False # [1, 3] would be fine as well + for simplex in st.get_skeleton(0): + assert simplex[1] == 1. def test_reset_filtration(): st = SimplexTree() @@ -447,4 +442,116 @@ def test_persistence_intervals_in_dimension(): assert np.array_equal(H2, np.array([[ 0., float("inf")]])) # Test empty case assert st.persistence_intervals_in_dimension(3).shape == (0, 2) -
\ No newline at end of file + +def test_equality_operator(): + st1 = SimplexTree() + st2 = SimplexTree() + + assert st1 == st2 + + st1.insert([1,2,3], 4.) + assert st1 != st2 + + st2.insert([1,2,3], 4.) + assert st1 == st2 + +def test_simplex_tree_deep_copy(): + st = SimplexTree() + st.insert([1, 2, 3], 0.) + # compute persistence only on the original + st.compute_persistence() + + st_copy = st.copy() + assert st_copy == st + st_filt_list = list(st.get_filtration()) + + # check persistence is not copied + assert st.__is_persistence_defined() == True + assert st_copy.__is_persistence_defined() == False + + # remove something in the copy and check the copy is included in the original + st_copy.remove_maximal_simplex([1, 2, 3]) + a_filt_list = list(st_copy.get_filtration()) + assert len(a_filt_list) < len(st_filt_list) + + for a_splx in a_filt_list: + assert a_splx in st_filt_list + + # test double free + del st + del st_copy + +def test_simplex_tree_deep_copy_constructor(): + st = SimplexTree() + st.insert([1, 2, 3], 0.) + # compute persistence only on the original + st.compute_persistence() + + st_copy = SimplexTree(st) + assert st_copy == st + st_filt_list = list(st.get_filtration()) + + # check persistence is not copied + assert st.__is_persistence_defined() == True + assert st_copy.__is_persistence_defined() == False + + # remove something in the copy and check the copy is included in the original + st_copy.remove_maximal_simplex([1, 2, 3]) + a_filt_list = list(st_copy.get_filtration()) + assert len(a_filt_list) < len(st_filt_list) + + for a_splx in a_filt_list: + assert a_splx in st_filt_list + + # test double free + del st + del st_copy + +def test_simplex_tree_constructor_exception(): + with pytest.raises(TypeError): + st = SimplexTree(other = "Construction from a string shall raise an exception") + +def test_expansion_with_blocker(): + st=SimplexTree() + st.insert([0,1],0) + st.insert([0,2],1) + st.insert([0,3],2) + st.insert([1,2],3) + st.insert([1,3],4) + st.insert([2,3],5) + st.insert([2,4],6) + st.insert([3,6],7) + st.insert([4,5],8) + st.insert([4,6],9) + st.insert([5,6],10) + st.insert([6],10) + + def blocker(simplex): + try: + # Block all simplices that countains vertex 6 + simplex.index(6) + print(simplex, ' is blocked') + return True + except ValueError: + print(simplex, ' is accepted') + st.assign_filtration(simplex, st.filtration(simplex) + 1.) + return False + + st.expansion_with_blocker(2, blocker) + assert st.num_simplices() == 22 + assert st.dimension() == 2 + assert st.find([4,5,6]) == False + assert st.filtration([0,1,2]) == 4. + assert st.filtration([0,1,3]) == 5. + assert st.filtration([0,2,3]) == 6. + assert st.filtration([1,2,3]) == 6. + + st.expansion_with_blocker(3, blocker) + assert st.num_simplices() == 23 + assert st.dimension() == 3 + assert st.find([4,5,6]) == False + assert st.filtration([0,1,2]) == 4. + assert st.filtration([0,1,3]) == 5. + assert st.filtration([0,2,3]) == 6. + assert st.filtration([1,2,3]) == 6. + assert st.filtration([0,1,2,3]) == 7. |