diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/python/CMakeLists.txt | 9 | ||||
-rw-r--r-- | src/python/doc/cubical_complex_sum.inc | 25 | ||||
-rw-r--r-- | src/python/doc/cubical_complex_tflow_itf_ref.rst | 42 | ||||
-rw-r--r-- | src/python/doc/differentiation_sum.inc | 11 | ||||
-rw-r--r-- | src/python/doc/installation.rst | 6 | ||||
-rw-r--r-- | src/python/doc/ls_simplex_tree_tflow_itf_ref.rst | 66 | ||||
-rw-r--r-- | src/python/doc/rips_complex_sum.inc | 5 | ||||
-rw-r--r-- | src/python/doc/rips_complex_tflow_itf_ref.rst | 41 | ||||
-rw-r--r-- | src/python/doc/simplex_tree_sum.inc | 23 | ||||
-rw-r--r-- | src/python/gudhi/tensorflow/__init__.py | 5 | ||||
-rw-r--r-- | src/python/gudhi/tensorflow/cubical_layer.py | 66 | ||||
-rw-r--r-- | src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py | 77 | ||||
-rw-r--r-- | src/python/gudhi/tensorflow/rips_layer.py | 75 | ||||
-rw-r--r-- | src/python/test/test_diff.py | 67 |
14 files changed, 494 insertions, 24 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 4a017251..8e8bf59a 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -67,6 +67,7 @@ if(PYTHONINTERP_FOUND) set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'euclidean_strong_witness_complex', ") # Modules that should not be auto-imported in __init__.py set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ") + set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'tensorflow', ") set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ") set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'point_cloud', ") set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'weighted_rips_complex', ") @@ -269,7 +270,8 @@ if(PYTHONINTERP_FOUND) # Other .py files file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/") - file(COPY "gudhi/wasserstein" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") + file(COPY "gudhi/wasserstein" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") + file(COPY "gudhi/tensorflow" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") file(COPY "gudhi/point_cloud" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") file(COPY "gudhi/clustering" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py") file(COPY "gudhi/weighted_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") @@ -535,6 +537,11 @@ if(PYTHONINTERP_FOUND) add_gudhi_py_test(test_representations) endif() + # Differentiation + if(TENSORFLOW_FOUND) + add_gudhi_py_test(test_diff) + endif() + # Time Delay add_gudhi_py_test(test_time_delay) diff --git a/src/python/doc/cubical_complex_sum.inc b/src/python/doc/cubical_complex_sum.inc index 87db184d..90ec9fc2 100644 --- a/src/python/doc/cubical_complex_sum.inc +++ b/src/python/doc/cubical_complex_sum.inc @@ -1,14 +1,17 @@ .. table:: :widths: 30 40 30 - +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+ - | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko | - | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | | - | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 | - | :figclass: align-center | | | - | | | :License: MIT | - | | | | - +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+ - | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` | - | | * :doc:`periodic_cubical_complex_ref` | - +--------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+ + +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+ + | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko | + | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | | + | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 | + | :figclass: align-center | | | + | | | :License: MIT | + | | | | + +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+ + | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` | + | | * :doc:`periodic_cubical_complex_ref` | + +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+ + | | * :doc:`cubical_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ | + | | | | + +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+ diff --git a/src/python/doc/cubical_complex_tflow_itf_ref.rst b/src/python/doc/cubical_complex_tflow_itf_ref.rst new file mode 100644 index 00000000..a907dfce --- /dev/null +++ b/src/python/doc/cubical_complex_tflow_itf_ref.rst @@ -0,0 +1,42 @@ +:orphan: + +.. To get rid of WARNING: document isn't included in any toctree + +TensorFlow layer for cubical persistence +######################################## + +.. include:: differentiation_sum.inc + +Example of gradient computed from cubical persistence +----------------------------------------------------- + +.. testcode:: + + from gudhi.tensorflow import * + import numpy as np + import tensorflow as tf + + Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32) + X = tf.Variable(initial_value=Xinit, trainable=True) + cl = CubicalLayer(dimension=0) + + with tf.GradientTape() as tape: + dgm = cl.call(X) + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + + grads = tape.gradient(loss, [X]) + print(grads[0].numpy()) + +.. testoutput:: + + [[ 0. 0. 0. ] + [ 0. 0.5 0. ] + [ 0. 0. -0.5]] + +Documentation for CubicalLayer +------------------------------ + +.. autoclass:: gudhi.tensorflow.CubicalLayer + :members: + :special-members: __init__ + :show-inheritance: diff --git a/src/python/doc/differentiation_sum.inc b/src/python/doc/differentiation_sum.inc new file mode 100644 index 00000000..3dd8e59c --- /dev/null +++ b/src/python/doc/differentiation_sum.inc @@ -0,0 +1,11 @@ +.. list-table:: + :widths: 40 30 30 + :header-rows: 0 + + * - :Since: GUDHI 3.5.0 + - :License: MIT + - :Requires: `TensorFlow <installation.html#tensorflow>`_ + +We provide TensorFlow 2 models that can handle automatic differentiation for the computation of persistence diagrams from complexes available in the Gudhi library. +This includes simplex trees, cubical complexes and Vietoris-Rips complexes. Detailed example on how to use these layers in practice are available +in the following `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-optimization.ipynb>`_. diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst index 35c344e3..25eb7a90 100644 --- a/src/python/doc/installation.rst +++ b/src/python/doc/installation.rst @@ -393,7 +393,11 @@ mathematics, science, and engineering. TensorFlow ---------- -`TensorFlow <https://www.tensorflow.org>`_ is currently only used in some automatic differentiation tests. +The :doc:`cubical complex </cubical_complex_tflow_itf_ref>`, :doc:`simplex tree </ls_simplex_tree_tflow_itf_ref>` +and :doc:`Rips complex </rips_complex_tflow_itf_ref>` modules require `TensorFlow <https://www.tensorflow.org>`_ +for incorporating them in neural nets. + +`TensorFlow <https://www.tensorflow.org>`_ is also used in some automatic differentiation tests. Bug reports and contributions ***************************** diff --git a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst new file mode 100644 index 00000000..26cf1ff2 --- /dev/null +++ b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst @@ -0,0 +1,66 @@ +:orphan: + +.. To get rid of WARNING: document isn't included in any toctree + +TensorFlow layer for lower-star persistence on simplex trees +############################################################ + +.. include:: differentiation_sum.inc + +Example of gradient computed from lower-star filtration of a simplex tree +------------------------------------------------------------------------- + +.. testcode:: + + from gudhi.tensorflow import * + import numpy as np + import tensorflow as tf + import gudhi as gd + + st = gd.SimplexTree() + st.insert([0]) + st.insert([1]) + st.insert([2]) + st.insert([3]) + st.insert([4]) + st.insert([5]) + st.insert([6]) + st.insert([7]) + st.insert([8]) + st.insert([9]) + st.insert([10]) + st.insert([0, 1]) + st.insert([1, 2]) + st.insert([2, 3]) + st.insert([3, 4]) + st.insert([4, 5]) + st.insert([5, 6]) + st.insert([6, 7]) + st.insert([7, 8]) + st.insert([8, 9]) + st.insert([9, 10]) + + Finit = np.array([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=np.float32) + F = tf.Variable(initial_value=Finit, trainable=True) + sl = LowerStarSimplexTreeLayer(simplextree=st, dimension=0) + + with tf.GradientTape() as tape: + dgm = sl.call(F) + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + + grads = tape.gradient(loss, [F]) + print(grads[0].indices.numpy()) + print(grads[0].values.numpy()) + +.. testoutput:: + + [2 4] + [-1. 1.] + +Documentation for LowerStarSimplexTreeLayer +------------------------------------------- + +.. autoclass:: gudhi.tensorflow.LowerStarSimplexTreeLayer + :members: + :special-members: __init__ + :show-inheritance: diff --git a/src/python/doc/rips_complex_sum.inc b/src/python/doc/rips_complex_sum.inc index 2cb24990..6931ebee 100644 --- a/src/python/doc/rips_complex_sum.inc +++ b/src/python/doc/rips_complex_sum.inc @@ -11,4 +11,7 @@ | | | | +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+ | * :doc:`rips_complex_user` | * :doc:`rips_complex_ref` | - +----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+ + +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+ + | | * :doc:`rips_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ | + | | | | + +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+ diff --git a/src/python/doc/rips_complex_tflow_itf_ref.rst b/src/python/doc/rips_complex_tflow_itf_ref.rst new file mode 100644 index 00000000..7300eba0 --- /dev/null +++ b/src/python/doc/rips_complex_tflow_itf_ref.rst @@ -0,0 +1,41 @@ +:orphan: + +.. To get rid of WARNING: document isn't included in any toctree + +TensorFlow layer for Vietoris-Rips persistence +############################################## + +.. include:: differentiation_sum.inc + +Example of gradient computed from Vietoris-Rips persistence +----------------------------------------------------------- + +.. testcode:: + + from gudhi.tensorflow import * + import numpy as np + import tensorflow as tf + + Xinit = np.array([[1.,1.],[2.,2.]], dtype=np.float32) + X = tf.Variable(initial_value=Xinit, trainable=True) + rl = RipsLayer(maximum_edge_length=2., dimension=0) + + with tf.GradientTape() as tape: + dgm = rl.call(X) + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + + grads = tape.gradient(loss, [X]) + print(grads[0].numpy()) + +.. testoutput:: + + [[-0.5 -0.5] + [ 0.5 0.5]] + +Documentation for RipsLayer +--------------------------- + +.. autoclass:: gudhi.tensorflow.RipsLayer + :members: + :special-members: __init__ + :show-inheritance: diff --git a/src/python/doc/simplex_tree_sum.inc b/src/python/doc/simplex_tree_sum.inc index a8858f16..3ad1292c 100644 --- a/src/python/doc/simplex_tree_sum.inc +++ b/src/python/doc/simplex_tree_sum.inc @@ -1,13 +1,16 @@ .. table:: :widths: 30 40 30 - +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------+ - | .. figure:: | The simplex tree is an efficient and flexible data structure for | :Author: Clément Maria | - | ../../doc/Simplex_tree/Simplex_tree_representation.png | representing general (filtered) simplicial complexes. | | - | :alt: Simplex tree representation | | :Since: GUDHI 2.0.0 | - | :figclass: align-center | The data structure is described in | | - | | :cite:`boissonnatmariasimplextreealgorithmica` | :License: MIT | - | | | | - +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------+ - | * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` | - +----------------------------------------------------------------+------------------------------------------------------------------------------------------------------+ + +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+ + | .. figure:: | The simplex tree is an efficient and flexible data structure for | :Author: Clément Maria | + | ../../doc/Simplex_tree/Simplex_tree_representation.png | representing general (filtered) simplicial complexes. | | + | :alt: Simplex tree representation | | :Since: GUDHI 2.0.0 | + | :figclass: align-center | The data structure is described in | | + | | :cite:`boissonnatmariasimplextreealgorithmica` | :License: MIT | + | | | | + +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+ + | * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` | + +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+ + | | * :doc:`ls_simplex_tree_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ | + | | | | + +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+ diff --git a/src/python/gudhi/tensorflow/__init__.py b/src/python/gudhi/tensorflow/__init__.py new file mode 100644 index 00000000..1599cf52 --- /dev/null +++ b/src/python/gudhi/tensorflow/__init__.py @@ -0,0 +1,5 @@ +from .cubical_layer import CubicalLayer +from .lower_star_simplex_tree_layer import LowerStarSimplexTreeLayer +from .rips_layer import RipsLayer + +__all__ = ["LowerStarSimplexTreeLayer", "RipsLayer", "CubicalLayer"] diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py new file mode 100644 index 00000000..e36adec5 --- /dev/null +++ b/src/python/gudhi/tensorflow/cubical_layer.py @@ -0,0 +1,66 @@ +import numpy as np +import tensorflow as tf +from ..cubical_complex import CubicalComplex + +###################### +# Cubical filtration # +###################### + +# The parameters of the model are the pixel values. + +def _Cubical(X, dimension): + # Parameters: X (image), + # dimension (homology dimension) + + # Compute the persistence pairs with Gudhi + cc = CubicalComplex(dimensions=X.shape, top_dimensional_cells=X.flatten()) + cc.persistence() + try: + cof = cc.cofaces_of_persistence_pairs()[0][dimension] + except IndexError: + cof = np.array([]) + + if len(cof) > 0: + # Sort points with distance-to-diagonal + Xs = X.shape + pers = [X[np.unravel_index(cof[idx,1], Xs)] - X[np.unravel_index(cof[idx,0], Xs)] for idx in range(len(cof))] + perm = np.argsort(pers) + cof = cof[perm[::-1]] + + # Retrieve and ouput image indices/pixels corresponding to positive and negative simplices + D = len(Xs) if len(cof) > 0 else 1 + ocof = np.array([0 for _ in range(D*2*cof.shape[0])]) + count = 0 + for idx in range(0,2*cof.shape[0],2): + ocof[D*idx:D*(idx+1)] = np.unravel_index(cof[count,0], Xs) + ocof[D*(idx+1):D*(idx+2)] = np.unravel_index(cof[count,1], Xs) + count += 1 + return np.array(ocof, dtype=np.int32) + +class CubicalLayer(tf.keras.layers.Layer): + """ + TensorFlow layer for computing cubical persistence out of a cubical complex + + Attributes: + dimension (int): homology dimension + """ + def __init__(self, dimension=1, **kwargs): + super().__init__(dynamic=True, **kwargs) + self.dimension = dimension + + def build(self): + super.build() + + def call(self, X): + """ + Compute persistence diagram associated to a cubical complex filtered by some pixel values + + Parameters: + X (TensorFlow variable): pixel values of the cubical complex + """ + # Compute pixels associated to positive and negative simplices + # Don't compute gradient for this operation + indices = tf.stop_gradient(_Cubical(X.numpy(), self.dimension)) + # Get persistence diagram by simply picking the corresponding entries in the image + dgm = tf.reshape(tf.gather_nd(X, tf.reshape(indices, [-1,len(X.shape)])), [-1,2]) + return dgm diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py new file mode 100644 index 00000000..fc963d2f --- /dev/null +++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py @@ -0,0 +1,77 @@ +import numpy as np +import tensorflow as tf + +######################################### +# Lower star filtration on simplex tree # +######################################### + +# The parameters of the model are the vertex function values of the simplex tree. + +def _LowerStarSimplexTree(simplextree, filtration, dimension): + # Parameters: simplextree (simplex tree on which to compute persistence) + # filtration (function values on the vertices of st), + # dimension (homology dimension), + + for s,_ in simplextree.get_filtration(): + simplextree.assign_filtration(s, -1e10) + + # Assign new filtration values + for i in range(simplextree.num_vertices()): + simplextree.assign_filtration([i], filtration[i]) + simplextree.make_filtration_non_decreasing() + + # Compute persistence diagram + dgm = simplextree.persistence() + + # Get vertex pairs for optimization. First, get all simplex pairs + pairs = simplextree.persistence_pairs() + + # Then, loop over all simplex pairs + indices, pers = [], [] + for s1, s2 in pairs: + # Select pairs with good homological dimension and finite lifetime + if len(s1) == dimension+1 and len(s2) > 0: + # Get IDs of the vertices corresponding to the filtration values of the simplices + l1, l2 = np.array(s1), np.array(s2) + i1 = l1[np.argmax(filtration[l1])] + i2 = l2[np.argmax(filtration[l2])] + indices.append(i1) + indices.append(i2) + # Compute lifetime + pers.append(simplextree.filtration(s2)-simplextree.filtration(s1)) + + # Sort vertex pairs wrt lifetime + perm = np.argsort(pers) + indices = np.reshape(indices, [-1,2])[perm][::-1,:].flatten() + + return np.array(indices, dtype=np.int32) + +class LowerStarSimplexTreeLayer(tf.keras.layers.Layer): + """ + TensorFlow layer for computing lower-star persistence out of a simplex tree + + Attributes: + simplextree (gudhi.SimplexTree()): underlying simplex tree + dimension (int): homology dimension + """ + def __init__(self, simplextree, dimension=0, **kwargs): + super().__init__(dynamic=True, **kwargs) + self.dimension = dimension + self.simplextree = simplextree + + def build(self): + super.build() + + def call(self, filtration): + """ + Compute lower-star persistence diagram associated to a function defined on the vertices of the simplex tree + + Parameters: + F (TensorFlow variable): filter function values over the vertices of the simplex tree + """ + # Don't try to compute gradients for the vertex pairs + indices = tf.stop_gradient(_LowerStarSimplexTree(self.simplextree, filtration.numpy(), self.dimension)) + # Get persistence diagram + self.dgm = tf.reshape(tf.gather(filtration, indices), [-1,2]) + return self.dgm + diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py new file mode 100644 index 00000000..373e021e --- /dev/null +++ b/src/python/gudhi/tensorflow/rips_layer.py @@ -0,0 +1,75 @@ +import numpy as np +import tensorflow as tf +from ..rips_complex import RipsComplex + +############################ +# Vietoris-Rips filtration # +############################ + +# The parameters of the model are the point coordinates. + +def _Rips(DX, max_edge, dimension): + # Parameters: DX (distance matrix), + # max_edge (maximum edge length for Rips filtration), + # dimension (homology dimension) + + # Compute the persistence pairs with Gudhi + rc = RipsComplex(distance_matrix=DX, max_edge_length=max_edge) + st = rc.create_simplex_tree(max_dimension=dimension+1) + dgm = st.persistence() + pairs = st.persistence_pairs() + + # Retrieve vertices v_a and v_b by picking the ones achieving the maximal + # distance among all pairwise distances between the simplex vertices + indices, pers = [], [] + for s1, s2 in pairs: + if len(s1) == dimension+1 and len(s2) > 0: + l1, l2 = np.array(s1), np.array(s2) + i1 = [l1[v] for v in np.unravel_index(np.argmax(DX[l1,:][:,l1]),[len(l1), len(l1)])] + i2 = [l2[v] for v in np.unravel_index(np.argmax(DX[l2,:][:,l2]),[len(l2), len(l2)])] + indices.append(i1) + indices.append(i2) + pers.append(st.filtration(s2)-st.filtration(s1)) + + # Sort points with distance-to-diagonal + perm = np.argsort(pers) + indices = np.reshape(indices, [-1,4])[perm][::-1,:].flatten() + + return np.array(indices, dtype=np.int32) + +class RipsLayer(tf.keras.layers.Layer): + """ + TensorFlow layer for computing Rips persistence out of a point cloud + + Attributes: + maximum_edge_length (float): maximum edge length for the Rips complex + dimension (int): homology dimension + """ + def __init__(self, maximum_edge_length=12, dimension=1, **kwargs): + super().__init__(dynamic=True, **kwargs) + self.max_edge = maximum_edge_length + self.dimension = dimension + + def build(self): + super.build() + + def call(self, X): + """ + Compute Rips persistence diagram associated to a point cloud + + Parameters: + X (TensorFlow variable): point cloud of shape [number of points, number of dimensions] + """ + # Compute distance matrix + DX = tf.math.sqrt(tf.reduce_sum((tf.expand_dims(X, 1)-tf.expand_dims(X, 0))**2, 2)) + # Compute vertices associated to positive and negative simplices + # Don't compute gradient for this operation + indices = tf.stop_gradient(_Rips(DX.numpy(), self.max_edge, self.dimension)) + # Get persistence diagram by simply picking the corresponding entries in the distance matrix + if self.dimension > 0: + dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(indices, [-1,2])), [-1,2]) + else: + indices = tf.reshape(indices, [-1,2])[1::2,:] + dgm = tf.concat([tf.zeros([indices.shape[0],1]), tf.reshape(tf.gather_nd(DX, indices), [-1,1])], axis=1) + return dgm + diff --git a/src/python/test/test_diff.py b/src/python/test/test_diff.py new file mode 100644 index 00000000..73a03697 --- /dev/null +++ b/src/python/test/test_diff.py @@ -0,0 +1,67 @@ +from gudhi.tensorflow import * +import numpy as np +import tensorflow as tf +import gudhi as gd + +def test_rips_diff(): + + Xinit = np.array([[1.,1.],[2.,2.]], dtype=np.float32) + X = tf.Variable(initial_value=Xinit, trainable=True) + rl = RipsLayer(maximum_edge_length=2., dimension=0) + + with tf.GradientTape() as tape: + dgm = rl.call(X) + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + grads = tape.gradient(loss, [X]) + assert np.abs(grads[0].numpy()-np.array([[-.5,-.5],[.5,.5]])).sum() <= 1e-6 + + +def test_cubical_diff(): + + Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32) + X = tf.Variable(initial_value=Xinit, trainable=True) + cl = CubicalLayer(dimension=0) + + with tf.GradientTape() as tape: + dgm = cl.call(X) + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + grads = tape.gradient(loss, [X]) + assert np.abs(grads[0].numpy()-np.array([[0.,0.,0.],[0.,.5,0.],[0.,0.,-.5]])).sum() <= 1e-6 + +def test_st_diff(): + + st = gd.SimplexTree() + st.insert([0]) + st.insert([1]) + st.insert([2]) + st.insert([3]) + st.insert([4]) + st.insert([5]) + st.insert([6]) + st.insert([7]) + st.insert([8]) + st.insert([9]) + st.insert([10]) + st.insert([0, 1]) + st.insert([1, 2]) + st.insert([2, 3]) + st.insert([3, 4]) + st.insert([4, 5]) + st.insert([5, 6]) + st.insert([6, 7]) + st.insert([7, 8]) + st.insert([8, 9]) + st.insert([9, 10]) + + Finit = np.array([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=np.float32) + F = tf.Variable(initial_value=Finit, trainable=True) + sl = LowerStarSimplexTreeLayer(simplextree=st, dimension=0) + + with tf.GradientTape() as tape: + dgm = sl.call(F) + loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) + grads = tape.gradient(loss, [F]) + + assert np.array_equal(np.array(grads[0].indices), np.array([2,4])) + assert np.array_equal(np.array(grads[0].values), np.array([-1,1])) + |