summaryrefslogtreecommitdiff
path: root/src/python/test/test_simplex_tree.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_simplex_tree.py')
-rwxr-xr-xsrc/python/test/test_simplex_tree.py365
1 files changed, 225 insertions, 140 deletions
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index 54bafed5..2ccbfbf5 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -249,6 +249,7 @@ def test_make_filtration_non_decreasing():
assert st.filtration([3, 4]) == 2.0
assert st.filtration([4, 5]) == 2.0
+
def test_extend_filtration():
# Inserted simplex:
@@ -257,86 +258,87 @@ def test_extend_filtration():
# / \ /
# o o
# /2\ /3
- # o o
- # 1 0
-
- st = SimplexTree()
- st.insert([0,2])
- st.insert([1,2])
- st.insert([0,3])
- st.insert([2,5])
- st.insert([3,4])
- st.insert([3,5])
- st.assign_filtration([0], 1.)
- st.assign_filtration([1], 2.)
- st.assign_filtration([2], 3.)
- st.assign_filtration([3], 4.)
- st.assign_filtration([4], 5.)
- st.assign_filtration([5], 6.)
-
- assert list(st.get_filtration()) == [
- ([0, 2], 0.0),
- ([1, 2], 0.0),
- ([0, 3], 0.0),
- ([3, 4], 0.0),
- ([2, 5], 0.0),
- ([3, 5], 0.0),
- ([0], 1.0),
- ([1], 2.0),
- ([2], 3.0),
- ([3], 4.0),
- ([4], 5.0),
- ([5], 6.0)
+ # o o
+ # 1 0
+
+ st = SimplexTree()
+ st.insert([0, 2])
+ st.insert([1, 2])
+ st.insert([0, 3])
+ st.insert([2, 5])
+ st.insert([3, 4])
+ st.insert([3, 5])
+ st.assign_filtration([0], 1.0)
+ st.assign_filtration([1], 2.0)
+ st.assign_filtration([2], 3.0)
+ st.assign_filtration([3], 4.0)
+ st.assign_filtration([4], 5.0)
+ st.assign_filtration([5], 6.0)
+
+ assert list(st.get_filtration()) == [
+ ([0, 2], 0.0),
+ ([1, 2], 0.0),
+ ([0, 3], 0.0),
+ ([3, 4], 0.0),
+ ([2, 5], 0.0),
+ ([3, 5], 0.0),
+ ([0], 1.0),
+ ([1], 2.0),
+ ([2], 3.0),
+ ([3], 4.0),
+ ([4], 5.0),
+ ([5], 6.0),
]
-
+
st.extend_filtration()
-
- assert list(st.get_filtration()) == [
- ([6], -3.0),
- ([0], -2.0),
- ([1], -1.8),
- ([2], -1.6),
- ([0, 2], -1.6),
- ([1, 2], -1.6),
- ([3], -1.4),
- ([0, 3], -1.4),
- ([4], -1.2),
- ([3, 4], -1.2),
- ([5], -1.0),
- ([2, 5], -1.0),
- ([3, 5], -1.0),
- ([5, 6], 1.0),
- ([4, 6], 1.2),
- ([3, 6], 1.4),
+
+ assert list(st.get_filtration()) == [
+ ([6], -3.0),
+ ([0], -2.0),
+ ([1], -1.8),
+ ([2], -1.6),
+ ([0, 2], -1.6),
+ ([1, 2], -1.6),
+ ([3], -1.4),
+ ([0, 3], -1.4),
+ ([4], -1.2),
+ ([3, 4], -1.2),
+ ([5], -1.0),
+ ([2, 5], -1.0),
+ ([3, 5], -1.0),
+ ([5, 6], 1.0),
+ ([4, 6], 1.2),
+ ([3, 6], 1.4),
([3, 4, 6], 1.4),
- ([3, 5, 6], 1.4),
- ([2, 6], 1.6),
- ([2, 5, 6], 1.6),
- ([1, 6], 1.8),
- ([1, 2, 6], 1.8),
- ([0, 6], 2.0),
- ([0, 2, 6], 2.0),
- ([0, 3, 6], 2.0)
+ ([3, 5, 6], 1.4),
+ ([2, 6], 1.6),
+ ([2, 5, 6], 1.6),
+ ([1, 6], 1.8),
+ ([1, 2, 6], 1.8),
+ ([0, 6], 2.0),
+ ([0, 2, 6], 2.0),
+ ([0, 3, 6], 2.0),
]
- dgms = st.extended_persistence(min_persistence=-1.)
+ dgms = st.extended_persistence(min_persistence=-1.0)
assert len(dgms) == 4
# Sort by (death-birth) descending - we are only interested in those with the longest life span
for idx in range(4):
- dgms[idx] = sorted(dgms[idx], key=lambda x:(-abs(x[1][0]-x[1][1])))
+ dgms[idx] = sorted(dgms[idx], key=lambda x: (-abs(x[1][0] - x[1][1])))
+
+ assert dgms[0][0][1][0] == pytest.approx(2.0)
+ assert dgms[0][0][1][1] == pytest.approx(3.0)
+ assert dgms[1][0][1][0] == pytest.approx(5.0)
+ assert dgms[1][0][1][1] == pytest.approx(4.0)
+ assert dgms[2][0][1][0] == pytest.approx(1.0)
+ assert dgms[2][0][1][1] == pytest.approx(6.0)
+ assert dgms[3][0][1][0] == pytest.approx(6.0)
+ assert dgms[3][0][1][1] == pytest.approx(1.0)
- assert dgms[0][0][1][0] == pytest.approx(2.)
- assert dgms[0][0][1][1] == pytest.approx(3.)
- assert dgms[1][0][1][0] == pytest.approx(5.)
- assert dgms[1][0][1][1] == pytest.approx(4.)
- assert dgms[2][0][1][0] == pytest.approx(1.)
- assert dgms[2][0][1][1] == pytest.approx(6.)
- assert dgms[3][0][1][0] == pytest.approx(6.)
- assert dgms[3][0][1][1] == pytest.approx(1.)
def test_simplices_iterator():
st = SimplexTree()
-
+
assert st.insert([0, 1, 2], filtration=4.0) == True
assert st.insert([2, 3, 4], filtration=2.0) == True
@@ -346,9 +348,10 @@ def test_simplices_iterator():
print("filtration is: ", simplex[1])
assert st.filtration(simplex[0]) == simplex[1]
+
def test_collapse_edges():
st = SimplexTree()
-
+
assert st.insert([0, 1], filtration=1.0) == True
assert st.insert([1, 2], filtration=1.0) == True
assert st.insert([2, 3], filtration=1.0) == True
@@ -360,31 +363,33 @@ def test_collapse_edges():
st.collapse_edges()
assert st.num_simplices() == 9
- assert st.find([0, 2]) == False # [1, 3] would be fine as well
+ assert st.find([0, 2]) == False # [1, 3] would be fine as well
for simplex in st.get_skeleton(0):
- assert simplex[1] == 1.
+ assert simplex[1] == 1.0
+
def test_reset_filtration():
st = SimplexTree()
-
- assert st.insert([0, 1, 2], 3.) == True
- assert st.insert([0, 3], 2.) == True
- assert st.insert([3, 4, 5], 3.) == True
- assert st.insert([0, 1, 6, 7], 4.) == True
+
+ assert st.insert([0, 1, 2], 3.0) == True
+ assert st.insert([0, 3], 2.0) == True
+ assert st.insert([3, 4, 5], 3.0) == True
+ assert st.insert([0, 1, 6, 7], 4.0) == True
# Guaranteed by construction
for simplex in st.get_simplices():
- assert st.filtration(simplex[0]) >= 2.
-
+ assert st.filtration(simplex[0]) >= 2.0
+
# dimension until 5 even if simplex tree is of dimension 3 to test the limits
for dimension in range(5, -1, -1):
- st.reset_filtration(0., dimension)
+ st.reset_filtration(0.0, dimension)
for simplex in st.get_skeleton(3):
print(simplex)
if len(simplex[0]) < (dimension) + 1:
- assert st.filtration(simplex[0]) >= 2.
+ assert st.filtration(simplex[0]) >= 2.0
else:
- assert st.filtration(simplex[0]) == 0.
+ assert st.filtration(simplex[0]) == 0.0
+
def test_boundaries_iterator():
st = SimplexTree()
@@ -400,16 +405,17 @@ def test_boundaries_iterator():
list(st.get_boundaries([]))
with pytest.raises(RuntimeError):
- list(st.get_boundaries([0, 4])) # (0, 4) does not exist
+ list(st.get_boundaries([0, 4])) # (0, 4) does not exist
with pytest.raises(RuntimeError):
- list(st.get_boundaries([6])) # (6) does not exist
+ list(st.get_boundaries([6])) # (6) does not exist
+
def test_persistence_intervals_in_dimension():
# Here is our triangulation of a 2-torus - taken from https://dioscuri-tda.org/Paris_TDA_Tutorial_2021.html
# 0-----3-----4-----0
# | \ | \ | \ | \ |
- # | \ | \ | \| \ |
+ # | \ | \ | \| \ |
# 1-----8-----7-----1
# | \ | \ | \ | \ |
# | \ | \ | \ | \ |
@@ -418,50 +424,52 @@ def test_persistence_intervals_in_dimension():
# | \ | \ | \ | \ |
# 0-----3-----4-----0
st = SimplexTree()
- st.insert([0,1,8])
- st.insert([0,3,8])
- st.insert([3,7,8])
- st.insert([3,4,7])
- st.insert([1,4,7])
- st.insert([0,1,4])
- st.insert([1,2,5])
- st.insert([1,5,8])
- st.insert([5,6,8])
- st.insert([6,7,8])
- st.insert([2,6,7])
- st.insert([1,2,7])
- st.insert([0,2,3])
- st.insert([2,3,5])
- st.insert([3,4,5])
- st.insert([4,5,6])
- st.insert([0,4,6])
- st.insert([0,2,6])
+ st.insert([0, 1, 8])
+ st.insert([0, 3, 8])
+ st.insert([3, 7, 8])
+ st.insert([3, 4, 7])
+ st.insert([1, 4, 7])
+ st.insert([0, 1, 4])
+ st.insert([1, 2, 5])
+ st.insert([1, 5, 8])
+ st.insert([5, 6, 8])
+ st.insert([6, 7, 8])
+ st.insert([2, 6, 7])
+ st.insert([1, 2, 7])
+ st.insert([0, 2, 3])
+ st.insert([2, 3, 5])
+ st.insert([3, 4, 5])
+ st.insert([4, 5, 6])
+ st.insert([0, 4, 6])
+ st.insert([0, 2, 6])
st.compute_persistence(persistence_dim_max=True)
-
+
H0 = st.persistence_intervals_in_dimension(0)
- assert np.array_equal(H0, np.array([[ 0., float("inf")]]))
+ assert np.array_equal(H0, np.array([[0.0, float("inf")]]))
H1 = st.persistence_intervals_in_dimension(1)
- assert np.array_equal(H1, np.array([[ 0., float("inf")], [ 0., float("inf")]]))
+ assert np.array_equal(H1, np.array([[0.0, float("inf")], [0.0, float("inf")]]))
H2 = st.persistence_intervals_in_dimension(2)
- assert np.array_equal(H2, np.array([[ 0., float("inf")]]))
+ assert np.array_equal(H2, np.array([[0.0, float("inf")]]))
# Test empty case
assert st.persistence_intervals_in_dimension(3).shape == (0, 2)
+
def test_equality_operator():
st1 = SimplexTree()
st2 = SimplexTree()
assert st1 == st2
- st1.insert([1,2,3], 4.)
+ st1.insert([1, 2, 3], 4.0)
assert st1 != st2
- st2.insert([1,2,3], 4.)
+ st2.insert([1, 2, 3], 4.0)
assert st1 == st2
+
def test_simplex_tree_deep_copy():
st = SimplexTree()
- st.insert([1, 2, 3], 0.)
+ st.insert([1, 2, 3], 0.0)
# compute persistence only on the original
st.compute_persistence()
@@ -480,14 +488,15 @@ def test_simplex_tree_deep_copy():
for a_splx in a_filt_list:
assert a_splx in st_filt_list
-
+
# test double free
del st
del st_copy
+
def test_simplex_tree_deep_copy_constructor():
st = SimplexTree()
- st.insert([1, 2, 3], 0.)
+ st.insert([1, 2, 3], 0.0)
# compute persistence only on the original
st.compute_persistence()
@@ -506,56 +515,132 @@ def test_simplex_tree_deep_copy_constructor():
for a_splx in a_filt_list:
assert a_splx in st_filt_list
-
+
# test double free
del st
del st_copy
+
def test_simplex_tree_constructor_exception():
with pytest.raises(TypeError):
- st = SimplexTree(other = "Construction from a string shall raise an exception")
+ st = SimplexTree(other="Construction from a string shall raise an exception")
+
+
+def test_create_from_array():
+ a = np.array([[1, 4, 13, 6], [4, 3, 11, 5], [13, 11, 10, 12], [6, 5, 12, 2]])
+ st = SimplexTree.create_from_array(a, max_filtration=5.0)
+ assert list(st.get_filtration()) == [([0], 1.0), ([3], 2.0), ([1], 3.0), ([0, 1], 4.0), ([1, 3], 5.0)]
+
+
+def test_insert_edges_from_coo_matrix():
+ try:
+ from scipy.sparse import coo_matrix
+ from scipy.spatial import cKDTree
+ except ImportError:
+ print("Skipping, no SciPy")
+ return
+
+ st = SimplexTree()
+ st.insert([1, 2, 7], 7)
+ row = np.array([2, 5, 3])
+ col = np.array([1, 4, 6])
+ dat = np.array([1, 2, 3])
+ edges = coo_matrix((dat, (row, col)))
+ st.insert_edges_from_coo_matrix(edges)
+ assert list(st.get_filtration()) == [
+ ([1], 1.0),
+ ([2], 1.0),
+ ([1, 2], 1.0),
+ ([4], 2.0),
+ ([5], 2.0),
+ ([4, 5], 2.0),
+ ([3], 3.0),
+ ([6], 3.0),
+ ([3, 6], 3.0),
+ ([7], 7.0),
+ ([1, 7], 7.0),
+ ([2, 7], 7.0),
+ ([1, 2, 7], 7.0),
+ ]
+
+ pts = np.random.rand(100, 2)
+ tree = cKDTree(pts)
+ edges = tree.sparse_distance_matrix(tree, max_distance=0.15, output_type="coo_matrix")
+ st = SimplexTree()
+ st.insert_edges_from_coo_matrix(edges)
+ assert 100 < st.num_simplices() < 1000
+
+
+def test_insert_batch():
+ st = SimplexTree()
+ # vertices
+ st.insert_batch(np.array([[6, 1, 5]]), np.array([-5.0, 2.0, -3.0]))
+ # triangles
+ st.insert_batch(np.array([[2, 10], [5, 0], [6, 11]]), np.array([4.0, 0.0]))
+ # edges
+ st.insert_batch(np.array([[1, 5], [2, 5]]), np.array([1.0, 3.0]))
+
+ assert list(st.get_filtration()) == [
+ ([6], -5.0),
+ ([5], -3.0),
+ ([0], 0.0),
+ ([10], 0.0),
+ ([0, 10], 0.0),
+ ([11], 0.0),
+ ([0, 11], 0.0),
+ ([10, 11], 0.0),
+ ([0, 10, 11], 0.0),
+ ([1], 1.0),
+ ([2], 1.0),
+ ([1, 2], 1.0),
+ ([2, 5], 4.0),
+ ([2, 6], 4.0),
+ ([5, 6], 4.0),
+ ([2, 5, 6], 4.0),
+ ]
+
def test_expansion_with_blocker():
- st=SimplexTree()
- st.insert([0,1],0)
- st.insert([0,2],1)
- st.insert([0,3],2)
- st.insert([1,2],3)
- st.insert([1,3],4)
- st.insert([2,3],5)
- st.insert([2,4],6)
- st.insert([3,6],7)
- st.insert([4,5],8)
- st.insert([4,6],9)
- st.insert([5,6],10)
- st.insert([6],10)
+ st = SimplexTree()
+ st.insert([0, 1], 0)
+ st.insert([0, 2], 1)
+ st.insert([0, 3], 2)
+ st.insert([1, 2], 3)
+ st.insert([1, 3], 4)
+ st.insert([2, 3], 5)
+ st.insert([2, 4], 6)
+ st.insert([3, 6], 7)
+ st.insert([4, 5], 8)
+ st.insert([4, 6], 9)
+ st.insert([5, 6], 10)
+ st.insert([6], 10)
def blocker(simplex):
try:
# Block all simplices that contain vertex 6
simplex.index(6)
- print(simplex, ' is blocked')
+ print(simplex, " is blocked")
return True
except ValueError:
- print(simplex, ' is accepted')
- st.assign_filtration(simplex, st.filtration(simplex) + 1.)
+ print(simplex, " is accepted")
+ st.assign_filtration(simplex, st.filtration(simplex) + 1.0)
return False
st.expansion_with_blocker(2, blocker)
assert st.num_simplices() == 22
assert st.dimension() == 2
- assert st.find([4,5,6]) == False
- assert st.filtration([0,1,2]) == 4.
- assert st.filtration([0,1,3]) == 5.
- assert st.filtration([0,2,3]) == 6.
- assert st.filtration([1,2,3]) == 6.
+ assert st.find([4, 5, 6]) == False
+ assert st.filtration([0, 1, 2]) == 4.0
+ assert st.filtration([0, 1, 3]) == 5.0
+ assert st.filtration([0, 2, 3]) == 6.0
+ assert st.filtration([1, 2, 3]) == 6.0
st.expansion_with_blocker(3, blocker)
assert st.num_simplices() == 23
assert st.dimension() == 3
- assert st.find([4,5,6]) == False
- assert st.filtration([0,1,2]) == 4.
- assert st.filtration([0,1,3]) == 5.
- assert st.filtration([0,2,3]) == 6.
- assert st.filtration([1,2,3]) == 6.
- assert st.filtration([0,1,2,3]) == 7.
+ assert st.find([4, 5, 6]) == False
+ assert st.filtration([0, 1, 2]) == 4.0
+ assert st.filtration([0, 1, 3]) == 5.0
+ assert st.filtration([0, 2, 3]) == 6.0
+ assert st.filtration([1, 2, 3]) == 6.0
+ assert st.filtration([0, 1, 2, 3]) == 7.0