summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test')
-rwxr-xr-xsrc/python/test/test_cubical_complex.py10
-rwxr-xr-xsrc/python/test/test_simplex_generators.py64
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py38
3 files changed, 108 insertions, 4 deletions
diff --git a/src/python/test/test_cubical_complex.py b/src/python/test/test_cubical_complex.py
index fce4875c..5c59db8f 100755
--- a/src/python/test/test_cubical_complex.py
+++ b/src/python/test/test_cubical_complex.py
@@ -147,3 +147,13 @@ def test_connected_sublevel_sets():
periodic_dimensions = periodic_dimensions)
assert cub.persistence() == [(0, (2.0, float("inf")))]
assert cub.betti_numbers() == [1, 0, 0]
+
+def test_cubical_generators():
+ cub = CubicalComplex(top_dimensional_cells = [[0, 0, 0], [0, 1, 0], [0, 0, 0]])
+ cub.persistence()
+ g = cub.cofaces_of_persistence_pairs()
+ assert len(g[0]) == 2
+ assert len(g[1]) == 1
+ assert np.array_equal(g[0][0], np.empty(shape=[0,2]))
+ assert np.array_equal(g[0][1], np.array([[7, 4]]))
+ assert np.array_equal(g[1][0], np.array([8]))
diff --git a/src/python/test/test_simplex_generators.py b/src/python/test/test_simplex_generators.py
new file mode 100755
index 00000000..8a9b4844
--- /dev/null
+++ b/src/python/test/test_simplex_generators.py
@@ -0,0 +1,64 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Marc Glisse
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+import gudhi
+import numpy as np
+
+
+def test_flag_generators():
+ pts = np.array([[0, 0], [0, 1.01], [1, 0], [1.02, 1.03], [100, 0], [100, 3.01], [103, 0], [103.02, 3.03]])
+ r = gudhi.RipsComplex(pts, max_edge_length=4)
+ st = r.create_simplex_tree(max_dimension=50)
+ st.persistence()
+ g = st.flag_persistence_generators()
+ assert np.array_equal(g[0], [[2, 2, 0], [1, 1, 0], [3, 3, 1], [6, 6, 4], [5, 5, 4], [7, 7, 5]])
+ assert len(g[1]) == 1
+ assert np.array_equal(g[1][0], [[3, 2, 2, 1]])
+ assert np.array_equal(g[2], [0, 4])
+ assert len(g[3]) == 1
+ assert np.array_equal(g[3][0], [[7, 6]])
+ # Compare trivial cases (where the simplex is the generator) with persistence_pairs.
+ # This still makes assumptions on the order of vertices in a simplex and could be more robust.
+ pairs = st.persistence_pairs()
+ assert {tuple(i) for i in g[0]} == {(i[0][0],) + tuple(i[1]) for i in pairs if len(i[0]) == 1 and len(i[1]) != 0}
+ assert {(i[0], i[1]) for i in g[1][0]} == {tuple(i[0]) for i in pairs if len(i[0]) == 2 and len(i[1]) != 0}
+ assert set(g[2]) == {i[0][0] for i in pairs if len(i[0]) == 1 and len(i[1]) == 0}
+ assert {(i[0], i[1]) for i in g[3][0]} == {tuple(i[0]) for i in pairs if len(i[0]) == 2 and len(i[1]) == 0}
+
+
+def test_lower_star_generators():
+ st = gudhi.SimplexTree()
+ st.insert([0, 1, 2], -10)
+ st.insert([0, 3], -10)
+ st.insert([1, 3], -10)
+ st.assign_filtration([2], -1)
+ st.assign_filtration([3], 0)
+ st.assign_filtration([0], 1)
+ st.assign_filtration([1], 2)
+ st.make_filtration_non_decreasing()
+ st.persistence(min_persistence=-1)
+ g = st.lower_star_persistence_generators()
+ assert len(g[0]) == 2
+ assert np.array_equal(g[0][0], [[0, 0], [3, 0], [1, 1]])
+ assert np.array_equal(g[0][1], [[1, 1]])
+ assert len(g[1]) == 2
+ assert np.array_equal(g[1][0], [2])
+ assert np.array_equal(g[1][1], [1])
+
+
+def test_empty():
+ st = gudhi.SimplexTree()
+ st.persistence()
+ assert st.lower_star_persistence_generators() == ([], [])
+ g = st.flag_persistence_generators()
+ assert np.array_equal(g[0], np.empty((0, 3)))
+ assert g[1] == []
+ assert np.array_equal(g[2], [])
+ assert g[3] == []
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 1a4acc1d..90d26809 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -80,14 +80,44 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
-def hera_wrap(delta):
+def hera_wrap(**extra):
def fun(*kargs,**kwargs):
- return hera(*kargs,**kwargs,delta=delta)
+ return hera(*kargs,**kwargs,**extra)
+ return fun
+
+def pot_wrap(**extra):
+ def fun(*kargs,**kwargs):
+ return pot(*kargs,**kwargs,**extra)
return fun
def test_wasserstein_distance_pot():
_basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
+ _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False)
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False)
- _basic_wasserstein(hera_wrap(.1), .1, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+
+def test_wasserstein_distance_grad():
+ import torch
+
+ diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
+ diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ assert diag1.grad is None and diag2.grad is None and diag3.grad is None
+ dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+ dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+ dist12.backward()
+ dist30.backward()
+ assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
+ diag4 = torch.tensor([[0., 10.]], requires_grad=True)
+ diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+ dist45.backward()
+ assert np.array_equal(diag4.grad, [[-1., -1.]])
+ assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
+ diag6 = torch.tensor([[5., 10.]], requires_grad=True)
+ pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
+ # https://github.com/jonasrauber/eagerpy/issues/6
+ # assert np.array_equal(diag6.grad, [[0., 0.]])