summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVincent Rouvreau <vincent.rouvreau@inria.fr>2022-02-11 23:11:26 +0100
committerVincent Rouvreau <vincent.rouvreau@inria.fr>2022-02-11 23:11:26 +0100
commitfb8ce008feadcaf6a936740a3ed54d50970c731c (patch)
treedef07b8b20af7aeb7a46de87753e3ad2331ace02
parentd8905deb600228b704c093838cb6ad339ef49ad6 (diff)
__copy__, __deepcopy__, copy, and copy ctors. Still pb with the doc
-rw-r--r--.github/next_release.md3
-rw-r--r--src/python/gudhi/simplex_tree.pyx54
-rwxr-xr-xsrc/python/test/test_simplex_tree.py102
3 files changed, 140 insertions, 19 deletions
diff --git a/.github/next_release.md b/.github/next_release.md
index e21b25c7..3946404b 100644
--- a/.github/next_release.md
+++ b/.github/next_release.md
@@ -13,6 +13,9 @@ Below is a list of changes made since GUDHI 3.5.0:
- [Representations](https://gudhi.inria.fr/python/latest/representations.html#gudhi.representations.vector_methods.BettiCurve)
- A more flexible Betti curve class capable of computing exact curves
+- [Simplex tree](https://gudhi.inria.fr/python/latest/simplex_tree_ref.html)
+ - `__copy__`, `__deepcopy__`, `copy` and copy constructors
+
- Installation
- Boost &ge; 1.66.0 is now required (was &ge; 1.56.0).
diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx
index 6b3116a4..ed7c3b92 100644
--- a/src/python/gudhi/simplex_tree.pyx
+++ b/src/python/gudhi/simplex_tree.pyx
@@ -30,6 +30,7 @@ cdef class SimplexTree:
# unfortunately 'cdef public Simplex_tree_interface_full_featured* thisptr' is not possible
# Use intptr_t instead to cast the pointer
cdef public intptr_t thisptr
+ cdef bool __thisptr_to_be_deleted
# Get the pointer casted as it should be
cdef Simplex_tree_interface_full_featured* get_ptr(self) nogil:
@@ -38,17 +39,36 @@ cdef class SimplexTree:
cdef Simplex_tree_persistence_interface * pcohptr
# Fake constructor that does nothing but documenting the constructor
- def __init__(self):
+ def __init__(self, other = None, copy = True):
"""SimplexTree constructor.
+ :param other: If `other` is a SimplexTree (default = None), the SimplexTree is constructed from a deep/shallow copy of `other`.
+ :type other: SimplexTree
+ :param copy: If `True`, the copy will be deep and if `False, the copy will be shallow. Default is `True`.
+ :type copy: bool
+ :returns: A simplex tree that is a (deep or shallow) copy of itself.
+ :rtype: SimplexTree
+ :note: copy constructor requires :func:`compute_persistence` to be launched again as the result is not copied.
"""
# The real cython constructor
- def __cinit__(self):
- self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured())
+ def __cinit__(self, other = None, copy = True):
+ cdef SimplexTree ostr
+ if other and type(other) is SimplexTree:
+ ostr = <SimplexTree> other
+ if copy:
+ self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured(dereference(ostr.get_ptr())))
+ else:
+ self.thisptr = ostr.thisptr
+ # Avoid double free - The original is in charge of deletion
+ self.__thisptr_to_be_deleted = False
+ else:
+ self.__thisptr_to_be_deleted = True
+ self.thisptr = <intptr_t>(new Simplex_tree_interface_full_featured())
def __dealloc__(self):
cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr()
- if ptr != NULL:
+ # Avoid double free - The original is in charge of deletion
+ if ptr != NULL and self.__thisptr_to_be_deleted:
del ptr
if self.pcohptr != NULL:
del self.pcohptr
@@ -63,20 +83,34 @@ cdef class SimplexTree:
"""
return self.pcohptr != NULL
- def copy(self):
+ def copy(self, deep=True):
"""
- :returns: A simplex tree that is a deep copy itself.
+ :param deep: If `True`, the copy will be deep and if `False`, the copy will be shallow. Default is `True`.
+ :type deep: bool
+ :returns: A simplex tree that is a (deep or shallow) copy of itself.
:rtype: SimplexTree
+ :note: copy requires :func:`compute_persistence` to be launched again as the result is not copied.
"""
stree = SimplexTree()
cdef Simplex_tree_interface_full_featured* stree_ptr
cdef Simplex_tree_interface_full_featured* self_ptr=self.get_ptr()
- with nogil:
- stree_ptr = new Simplex_tree_interface_full_featured(dereference(self_ptr))
-
- stree.thisptr = <intptr_t>(stree_ptr)
+ if deep:
+ with nogil:
+ stree_ptr = new Simplex_tree_interface_full_featured(dereference(self_ptr))
+
+ stree.thisptr = <intptr_t>(stree_ptr)
+ else:
+ stree.thisptr = self.thisptr
+ # Avoid double free - The original is in charge of deletion
+ stree.__thisptr_to_be_deleted = False
return stree
+ def __copy__(self):
+ return self.copy(deep=False)
+
+ def __deepcopy__(self):
+ return self.copy(deep=True)
+
def filtration(self, simplex):
"""This function returns the filtration value for a given N-simplex in
this simplicial complex, or +infinity if it is not in the complex.
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index dac45288..6db6d8fb 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -448,20 +448,104 @@ def test_persistence_intervals_in_dimension():
# Test empty case
assert st.persistence_intervals_in_dimension(3).shape == (0, 2)
-def test_simplex_tree_copy():
+def test_simplex_tree_deep_copy():
st = SimplexTree()
- st .insert([1,2,3], 0.)
- a = st.copy()
+ st.insert([1, 2, 3], 0.)
+ # persistence is not copied
+ st.compute_persistence()
+
+ st_copy = st.copy(deep=True)
# TODO(VR): when #463 is merged, replace with
- # assert a == st
- assert a.num_vertices() == st.num_vertices()
- assert a.num_simplices() == st.num_simplices()
+ # assert st_copy == st
+ assert st_copy.num_vertices() == st.num_vertices()
+ assert st_copy.num_simplices() == st.num_simplices()
st_filt_list = list(st.get_filtration())
- assert list(a.get_filtration()) == st_filt_list
+ assert list(st_copy.get_filtration()) == st_filt_list
+
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
- a.remove_maximal_simplex([1, 2, 3])
- a_filt_list = list(a.get_filtration())
+ st_copy.remove_maximal_simplex([1, 2, 3])
+ a_filt_list = list(st_copy.get_filtration())
assert len(a_filt_list) < len(st_filt_list)
for a_splx in a_filt_list:
assert a_splx in st_filt_list
+
+ # test double free
+ del st
+ del st_copy
+
+def test_simplex_tree_shallow_copy():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.)
+ # persistence is not copied
+ st.compute_persistence()
+
+ st_copy = st.copy(deep=False)
+ # TODO(VR): when #463 is merged, replace with
+ # assert st_copy == st
+ assert st_copy.num_vertices() == st.num_vertices()
+ assert st_copy.num_simplices() == st.num_simplices()
+ assert list(st_copy.get_filtration()) == list(st.get_filtration())
+
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ st_copy.assign_filtration([1, 2, 3], 2.)
+ assert list(st_copy.get_filtration()) == list(st.get_filtration())
+
+ # test double free
+ del st
+ del st_copy
+
+def test_simplex_tree_deep_copy_constructor():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.)
+ # persistence is not copied
+ st.compute_persistence()
+
+ st_copy = SimplexTree(st, copy = True)
+ # TODO(VR): when #463 is merged, replace with
+ # assert st_copy == st
+ assert st_copy.num_vertices() == st.num_vertices()
+ assert st_copy.num_simplices() == st.num_simplices()
+ st_filt_list = list(st.get_filtration())
+ assert list(st_copy.get_filtration()) == st_filt_list
+
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ st_copy.remove_maximal_simplex([1, 2, 3])
+ a_filt_list = list(st_copy.get_filtration())
+ assert len(a_filt_list) < len(st_filt_list)
+
+ for a_splx in a_filt_list:
+ assert a_splx in st_filt_list
+
+ # test double free
+ del st
+ del st_copy
+
+def test_simplex_tree_shallow_copy():
+ st = SimplexTree()
+ st.insert([1, 2, 3], 0.)
+ # persistence is not copied
+ st.compute_persistence()
+
+ st_copy = SimplexTree(st, copy = False)
+ # TODO(VR): when #463 is merged, replace with
+ # assert st_copy == st
+ assert st_copy.num_vertices() == st.num_vertices()
+ assert st_copy.num_simplices() == st.num_simplices()
+ assert list(st_copy.get_filtration()) == list(st.get_filtration())
+
+ assert st.__is_persistence_defined() == True
+ assert st_copy.__is_persistence_defined() == False
+
+ st_copy.assign_filtration([1, 2, 3], 2.)
+ assert list(st_copy.get_filtration()) == list(st.get_filtration())
+
+ # test double free
+ del st
+ del st_copy