summaryrefslogtreecommitdiff
path: root/src/cython/test/Simplex_tree_UT.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/cython/test/Simplex_tree_UT.py')
-rwxr-xr-xsrc/cython/test/Simplex_tree_UT.py106
1 files changed, 57 insertions, 49 deletions
diff --git a/src/cython/test/Simplex_tree_UT.py b/src/cython/test/Simplex_tree_UT.py
index b3ded581..392c6d71 100755
--- a/src/cython/test/Simplex_tree_UT.py
+++ b/src/cython/test/Simplex_tree_UT.py
@@ -2,56 +2,64 @@ import unittest
import gudhi
+
class TestSimplexTree(unittest.TestCase):
- def test_insertion(self):
- st = gudhi.SimplexTree()
-
- # insert test
- self.assertTrue(st.insert([0,1]))
- self.assertTrue(st.insert([0,1,2], filtration=4.0))
- self.assertEqual(st.num_simplices(), 7)
- self.assertEqual(st.num_vertices(), 3)
-
- # find test
- self.assertTrue(st.find([0,1,2]))
- self.assertTrue(st.find([0,1]))
- self.assertTrue(st.find([0,2]))
- self.assertTrue(st.find([0]))
- self.assertTrue(st.find([1]))
- self.assertTrue(st.find([2]))
- self.assertFalse(st.find([3]))
- self.assertFalse(st.find([0,3]))
- self.assertFalse(st.find([1,3]))
- self.assertFalse(st.find([2,3]))
-
- # filtration test
- st.set_filtration(5.0)
- st.initialize_filtration()
- self.assertEqual(st.get_filtration(), 5.0)
- self.assertEqual(st.filtration([0,1,2]), 4.0)
- self.assertEqual(st.filtration([0,2]), 4.0)
- self.assertEqual(st.filtration([1,2]), 4.0)
- self.assertEqual(st.filtration([2]), 4.0)
- self.assertEqual(st.filtration([0,1]), 0.0)
- self.assertEqual(st.filtration([0]), 0.0)
- self.assertEqual(st.filtration([1]), 0.0)
-
- # skeleton_tree test
- self.assertEqual(st.get_skeleton_tree(2), [([0, 1, 2], 4.0), ([0, 1], 0.0), ([0, 2], 4.0), ([0], 0.0), ([1, 2], 4.0), ([1], 0.0), ([2], 4.0)])
- self.assertEqual(st.get_skeleton_tree(1), [([0, 1], 0.0), ([0, 2], 4.0), ([0], 0.0), ([1, 2], 4.0), ([1], 0.0), ([2], 4.0)])
- self.assertEqual(st.get_skeleton_tree(0), [([0], 0.0), ([1], 0.0), ([2], 4.0)])
-
- # remove_maximal_simplex test
- self.assertEqual(st.get_coface_tree([0,1,2], 1), [])
- st.remove_maximal_simplex([0,1,2])
- self.assertEqual(st.get_skeleton_tree(2), [([0, 1], 0.0), ([0, 2], 4.0), ([0], 0.0), ([1, 2], 4.0), ([1], 0.0), ([2], 4.0)])
- self.assertFalse(st.find([0,1,2]))
- self.assertTrue(st.find([0,1]))
- self.assertTrue(st.find([0,2]))
- self.assertTrue(st.find([0]))
- self.assertTrue(st.find([1]))
- self.assertTrue(st.find([2]))
+ def test_insertion(self):
+ st = gudhi.SimplexTree()
+
+ # insert test
+ self.assertTrue(st.insert([0, 1]))
+ self.assertTrue(st.insert([0, 1, 2], filtration=4.0))
+ self.assertEqual(st.num_simplices(), 7)
+ self.assertEqual(st.num_vertices(), 3)
+
+ # find test
+ self.assertTrue(st.find([0, 1, 2]))
+ self.assertTrue(st.find([0, 1]))
+ self.assertTrue(st.find([0, 2]))
+ self.assertTrue(st.find([0]))
+ self.assertTrue(st.find([1]))
+ self.assertTrue(st.find([2]))
+ self.assertFalse(st.find([3]))
+ self.assertFalse(st.find([0, 3]))
+ self.assertFalse(st.find([1, 3]))
+ self.assertFalse(st.find([2, 3]))
+
+ # filtration test
+ st.set_filtration(5.0)
+ st.initialize_filtration()
+ self.assertEqual(st.get_filtration(), 5.0)
+ self.assertEqual(st.filtration([0, 1, 2]), 4.0)
+ self.assertEqual(st.filtration([0, 2]), 4.0)
+ self.assertEqual(st.filtration([1, 2]), 4.0)
+ self.assertEqual(st.filtration([2]), 4.0)
+ self.assertEqual(st.filtration([0, 1]), 0.0)
+ self.assertEqual(st.filtration([0]), 0.0)
+ self.assertEqual(st.filtration([1]), 0.0)
+
+ # skeleton_tree test
+ self.assertEqual(st.get_skeleton_tree(2),
+ [([0, 1, 2], 4.0), ([0, 1], 0.0), ([0, 2], 4.0),
+ ([0], 0.0), ([1, 2], 4.0), ([1], 0.0), ([2], 4.0)])
+ self.assertEqual(st.get_skeleton_tree(1),
+ [([0, 1], 0.0), ([0, 2], 4.0), ([0], 0.0),
+ ([1, 2], 4.0), ([1], 0.0), ([2], 4.0)])
+ self.assertEqual(st.get_skeleton_tree(0),
+ [([0], 0.0), ([1], 0.0), ([2], 4.0)])
+
+ # remove_maximal_simplex test
+ self.assertEqual(st.get_coface_tree([0, 1, 2], 1), [])
+ st.remove_maximal_simplex([0, 1, 2])
+ self.assertEqual(st.get_skeleton_tree(2),
+ [([0, 1], 0.0), ([0, 2], 4.0), ([0], 0.0),
+ ([1, 2], 4.0), ([1], 0.0), ([2], 4.0)])
+ self.assertFalse(st.find([0, 1, 2]))
+ self.assertTrue(st.find([0, 1]))
+ self.assertTrue(st.find([0, 2]))
+ self.assertTrue(st.find([0]))
+ self.assertTrue(st.find([1]))
+ self.assertTrue(st.find([2]))
if __name__ == '__main__':
- unittest.main() \ No newline at end of file
+ unittest.main()