summaryrefslogtreecommitdiff
path: root/src/python/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test')
-rwxr-xr-xsrc/python/test/test_representations.py20
-rwxr-xr-xsrc/python/test/test_simplex_tree.py19
2 files changed, 36 insertions, 3 deletions
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index e5c211a0..43c914f3 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -39,11 +39,11 @@ def test_multiple():
d2 = BottleneckDistance(epsilon=0.00001).fit_transform(l1)
d3 = pairwise_persistence_diagram_distances(l1, l1b, e=0.00001, n_jobs=4)
assert d1 == pytest.approx(d2)
- assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
+ assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2)
d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1)
print(d1.shape, d2.shape)
- assert d1 == pytest.approx(d2, rel=.02)
+ assert d1 == pytest.approx(d2, rel=0.02)
def test_dummy_atol():
@@ -53,8 +53,22 @@ def test_dummy_atol():
for weighting_method in ["cloud", "iidproba"]:
for contrast in ["gaussian", "laplacian", "indicator"]:
- atol_vectoriser = Atol(quantiser=KMeans(n_clusters=1, random_state=202006), weighting_method=weighting_method, contrast=contrast)
+ atol_vectoriser = Atol(
+ quantiser=KMeans(n_clusters=1, random_state=202006),
+ weighting_method=weighting_method,
+ contrast=contrast,
+ )
atol_vectoriser.fit([a, b, c])
atol_vectoriser(a)
atol_vectoriser.transform(X=[a, b, c])
+
+from gudhi.representations.vector_methods import BettiCurve
+
+
+def test_infinity():
+ a = np.array([[1.0, 8.0], [2.0, np.inf], [3.0, 4.0]])
+ c = BettiCurve(20, [0.0, 10.0])(a)
+ assert c[1] == 0
+ assert c[7] == 3
+ assert c[9] == 2
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index ac2b59c7..3b23fa0b 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -380,3 +380,22 @@ def test_reset_filtration():
assert st.filtration(simplex[0]) >= 2.
else:
assert st.filtration(simplex[0]) == 0.
+
+def test_boundaries_iterator():
+ st = SimplexTree()
+
+ assert st.insert([0, 1, 2, 3], filtration=1.0) == True
+ assert st.insert([1, 2, 3, 4], filtration=2.0) == True
+
+ assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)]
+ assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)]
+ assert list(st.get_boundaries([2])) == []
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([]))
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([0, 4])) # (0, 4) does not exist
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([6])) # (6) does not exist