summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/python/CMakeLists.txt9
-rw-r--r--src/python/doc/cubical_complex_sum.inc25
-rw-r--r--src/python/doc/cubical_complex_tflow_itf_ref.rst42
-rw-r--r--src/python/doc/differentiation_sum.inc11
-rw-r--r--src/python/doc/installation.rst6
-rw-r--r--src/python/doc/ls_simplex_tree_tflow_itf_ref.rst66
-rw-r--r--src/python/doc/rips_complex_sum.inc5
-rw-r--r--src/python/doc/rips_complex_tflow_itf_ref.rst41
-rw-r--r--src/python/doc/simplex_tree_sum.inc23
-rw-r--r--src/python/gudhi/tensorflow/__init__.py5
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py66
-rw-r--r--src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py77
-rw-r--r--src/python/gudhi/tensorflow/rips_layer.py75
-rw-r--r--src/python/test/test_diff.py67
14 files changed, 494 insertions, 24 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 4a017251..8e8bf59a 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -67,6 +67,7 @@ if(PYTHONINTERP_FOUND)
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'euclidean_strong_witness_complex', ")
# Modules that should not be auto-imported in __init__.py
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'tensorflow', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'point_cloud', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'weighted_rips_complex', ")
@@ -269,7 +270,8 @@ if(PYTHONINTERP_FOUND)
# Other .py files
file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
- file(COPY "gudhi/wasserstein" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
+ file(COPY "gudhi/wasserstein" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
+ file(COPY "gudhi/tensorflow" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/point_cloud" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/clustering" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py")
file(COPY "gudhi/weighted_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
@@ -535,6 +537,11 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_representations)
endif()
+ # Differentiation
+ if(TENSORFLOW_FOUND)
+ add_gudhi_py_test(test_diff)
+ endif()
+
# Time Delay
add_gudhi_py_test(test_time_delay)
diff --git a/src/python/doc/cubical_complex_sum.inc b/src/python/doc/cubical_complex_sum.inc
index 87db184d..90ec9fc2 100644
--- a/src/python/doc/cubical_complex_sum.inc
+++ b/src/python/doc/cubical_complex_sum.inc
@@ -1,14 +1,17 @@
.. table::
:widths: 30 40 30
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
- | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
- | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | |
- | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | | |
- | | | :License: MIT |
- | | | |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------+-----------------------------+
- | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
- | | * :doc:`periodic_cubical_complex_ref` |
- +--------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
+ +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+
+ | .. figure:: | The cubical complex represents a grid as a cell complex with | :Author: Pawel Dlotko |
+ | ../../doc/Bitmap_cubical_complex/Cubical_complex_representation.png | cells of all dimensions. | |
+ | :alt: Cubical complex representation | | :Since: GUDHI 2.0.0 |
+ | :figclass: align-center | | |
+ | | | :License: MIT |
+ | | | |
+ +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+
+ | * :doc:`cubical_complex_user` | * :doc:`cubical_complex_ref` |
+ | | * :doc:`periodic_cubical_complex_ref` |
+ +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+
+ | | * :doc:`cubical_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | | | |
+ +--------------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------+
diff --git a/src/python/doc/cubical_complex_tflow_itf_ref.rst b/src/python/doc/cubical_complex_tflow_itf_ref.rst
new file mode 100644
index 00000000..a907dfce
--- /dev/null
+++ b/src/python/doc/cubical_complex_tflow_itf_ref.rst
@@ -0,0 +1,42 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for cubical persistence
+########################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from cubical persistence
+-----------------------------------------------------
+
+.. testcode::
+
+ from gudhi.tensorflow import *
+ import numpy as np
+ import tensorflow as tf
+
+ Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(dimension=0)
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [X])
+ print(grads[0].numpy())
+
+.. testoutput::
+
+ [[ 0. 0. 0. ]
+ [ 0. 0.5 0. ]
+ [ 0. 0. -0.5]]
+
+Documentation for CubicalLayer
+------------------------------
+
+.. autoclass:: gudhi.tensorflow.CubicalLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/differentiation_sum.inc b/src/python/doc/differentiation_sum.inc
new file mode 100644
index 00000000..3dd8e59c
--- /dev/null
+++ b/src/python/doc/differentiation_sum.inc
@@ -0,0 +1,11 @@
+.. list-table::
+ :widths: 40 30 30
+ :header-rows: 0
+
+ * - :Since: GUDHI 3.5.0
+ - :License: MIT
+ - :Requires: `TensorFlow <installation.html#tensorflow>`_
+
+We provide TensorFlow 2 models that can handle automatic differentiation for the computation of persistence diagrams from complexes available in the Gudhi library.
+This includes simplex trees, cubical complexes and Vietoris-Rips complexes. Detailed example on how to use these layers in practice are available
+in the following `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-optimization.ipynb>`_.
diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst
index 35c344e3..25eb7a90 100644
--- a/src/python/doc/installation.rst
+++ b/src/python/doc/installation.rst
@@ -393,7 +393,11 @@ mathematics, science, and engineering.
TensorFlow
----------
-`TensorFlow <https://www.tensorflow.org>`_ is currently only used in some automatic differentiation tests.
+The :doc:`cubical complex </cubical_complex_tflow_itf_ref>`, :doc:`simplex tree </ls_simplex_tree_tflow_itf_ref>`
+and :doc:`Rips complex </rips_complex_tflow_itf_ref>` modules require `TensorFlow <https://www.tensorflow.org>`_
+for incorporating them in neural nets.
+
+`TensorFlow <https://www.tensorflow.org>`_ is also used in some automatic differentiation tests.
Bug reports and contributions
*****************************
diff --git a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
new file mode 100644
index 00000000..26cf1ff2
--- /dev/null
+++ b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
@@ -0,0 +1,66 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for lower-star persistence on simplex trees
+############################################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from lower-star filtration of a simplex tree
+-------------------------------------------------------------------------
+
+.. testcode::
+
+ from gudhi.tensorflow import *
+ import numpy as np
+ import tensorflow as tf
+ import gudhi as gd
+
+ st = gd.SimplexTree()
+ st.insert([0])
+ st.insert([1])
+ st.insert([2])
+ st.insert([3])
+ st.insert([4])
+ st.insert([5])
+ st.insert([6])
+ st.insert([7])
+ st.insert([8])
+ st.insert([9])
+ st.insert([10])
+ st.insert([0, 1])
+ st.insert([1, 2])
+ st.insert([2, 3])
+ st.insert([3, 4])
+ st.insert([4, 5])
+ st.insert([5, 6])
+ st.insert([6, 7])
+ st.insert([7, 8])
+ st.insert([8, 9])
+ st.insert([9, 10])
+
+ Finit = np.array([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=np.float32)
+ F = tf.Variable(initial_value=Finit, trainable=True)
+ sl = LowerStarSimplexTreeLayer(simplextree=st, dimension=0)
+
+ with tf.GradientTape() as tape:
+ dgm = sl.call(F)
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [F])
+ print(grads[0].indices.numpy())
+ print(grads[0].values.numpy())
+
+.. testoutput::
+
+ [2 4]
+ [-1. 1.]
+
+Documentation for LowerStarSimplexTreeLayer
+-------------------------------------------
+
+.. autoclass:: gudhi.tensorflow.LowerStarSimplexTreeLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/rips_complex_sum.inc b/src/python/doc/rips_complex_sum.inc
index 2cb24990..6931ebee 100644
--- a/src/python/doc/rips_complex_sum.inc
+++ b/src/python/doc/rips_complex_sum.inc
@@ -11,4 +11,7 @@
| | | |
+----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
| * :doc:`rips_complex_user` | * :doc:`rips_complex_ref` |
- +----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
+ | | * :doc:`rips_complex_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
diff --git a/src/python/doc/rips_complex_tflow_itf_ref.rst b/src/python/doc/rips_complex_tflow_itf_ref.rst
new file mode 100644
index 00000000..7300eba0
--- /dev/null
+++ b/src/python/doc/rips_complex_tflow_itf_ref.rst
@@ -0,0 +1,41 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for Vietoris-Rips persistence
+##############################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from Vietoris-Rips persistence
+-----------------------------------------------------------
+
+.. testcode::
+
+ from gudhi.tensorflow import *
+ import numpy as np
+ import tensorflow as tf
+
+ Xinit = np.array([[1.,1.],[2.,2.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ rl = RipsLayer(maximum_edge_length=2., dimension=0)
+
+ with tf.GradientTape() as tape:
+ dgm = rl.call(X)
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [X])
+ print(grads[0].numpy())
+
+.. testoutput::
+
+ [[-0.5 -0.5]
+ [ 0.5 0.5]]
+
+Documentation for RipsLayer
+---------------------------
+
+.. autoclass:: gudhi.tensorflow.RipsLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/src/python/doc/simplex_tree_sum.inc b/src/python/doc/simplex_tree_sum.inc
index a8858f16..3ad1292c 100644
--- a/src/python/doc/simplex_tree_sum.inc
+++ b/src/python/doc/simplex_tree_sum.inc
@@ -1,13 +1,16 @@
.. table::
:widths: 30 40 30
- +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------+
- | .. figure:: | The simplex tree is an efficient and flexible data structure for | :Author: Clément Maria |
- | ../../doc/Simplex_tree/Simplex_tree_representation.png | representing general (filtered) simplicial complexes. | |
- | :alt: Simplex tree representation | | :Since: GUDHI 2.0.0 |
- | :figclass: align-center | The data structure is described in | |
- | | :cite:`boissonnatmariasimplextreealgorithmica` | :License: MIT |
- | | | |
- +----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------+
- | * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` |
- +----------------------------------------------------------------+------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | .. figure:: | The simplex tree is an efficient and flexible data structure for | :Author: Clément Maria |
+ | ../../doc/Simplex_tree/Simplex_tree_representation.png | representing general (filtered) simplicial complexes. | |
+ | :alt: Simplex tree representation | | :Since: GUDHI 2.0.0 |
+ | :figclass: align-center | The data structure is described in | |
+ | | :cite:`boissonnatmariasimplextreealgorithmica` | :License: MIT |
+ | | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | * :doc:`simplex_tree_user` | * :doc:`simplex_tree_ref` |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
+ | | * :doc:`ls_simplex_tree_tflow_itf_ref` | :requires: `TensorFlow <installation.html#tensorflow>`_ |
+ | | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+---------------------------------------------------------+
diff --git a/src/python/gudhi/tensorflow/__init__.py b/src/python/gudhi/tensorflow/__init__.py
new file mode 100644
index 00000000..1599cf52
--- /dev/null
+++ b/src/python/gudhi/tensorflow/__init__.py
@@ -0,0 +1,5 @@
+from .cubical_layer import CubicalLayer
+from .lower_star_simplex_tree_layer import LowerStarSimplexTreeLayer
+from .rips_layer import RipsLayer
+
+__all__ = ["LowerStarSimplexTreeLayer", "RipsLayer", "CubicalLayer"]
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
new file mode 100644
index 00000000..e36adec5
--- /dev/null
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -0,0 +1,66 @@
+import numpy as np
+import tensorflow as tf
+from ..cubical_complex import CubicalComplex
+
+######################
+# Cubical filtration #
+######################
+
+# The parameters of the model are the pixel values.
+
+def _Cubical(X, dimension):
+ # Parameters: X (image),
+ # dimension (homology dimension)
+
+ # Compute the persistence pairs with Gudhi
+ cc = CubicalComplex(dimensions=X.shape, top_dimensional_cells=X.flatten())
+ cc.persistence()
+ try:
+ cof = cc.cofaces_of_persistence_pairs()[0][dimension]
+ except IndexError:
+ cof = np.array([])
+
+ if len(cof) > 0:
+ # Sort points with distance-to-diagonal
+ Xs = X.shape
+ pers = [X[np.unravel_index(cof[idx,1], Xs)] - X[np.unravel_index(cof[idx,0], Xs)] for idx in range(len(cof))]
+ perm = np.argsort(pers)
+ cof = cof[perm[::-1]]
+
+ # Retrieve and ouput image indices/pixels corresponding to positive and negative simplices
+ D = len(Xs) if len(cof) > 0 else 1
+ ocof = np.array([0 for _ in range(D*2*cof.shape[0])])
+ count = 0
+ for idx in range(0,2*cof.shape[0],2):
+ ocof[D*idx:D*(idx+1)] = np.unravel_index(cof[count,0], Xs)
+ ocof[D*(idx+1):D*(idx+2)] = np.unravel_index(cof[count,1], Xs)
+ count += 1
+ return np.array(ocof, dtype=np.int32)
+
+class CubicalLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing cubical persistence out of a cubical complex
+
+ Attributes:
+ dimension (int): homology dimension
+ """
+ def __init__(self, dimension=1, **kwargs):
+ super().__init__(dynamic=True, **kwargs)
+ self.dimension = dimension
+
+ def build(self):
+ super.build()
+
+ def call(self, X):
+ """
+ Compute persistence diagram associated to a cubical complex filtered by some pixel values
+
+ Parameters:
+ X (TensorFlow variable): pixel values of the cubical complex
+ """
+ # Compute pixels associated to positive and negative simplices
+ # Don't compute gradient for this operation
+ indices = tf.stop_gradient(_Cubical(X.numpy(), self.dimension))
+ # Get persistence diagram by simply picking the corresponding entries in the image
+ dgm = tf.reshape(tf.gather_nd(X, tf.reshape(indices, [-1,len(X.shape)])), [-1,2])
+ return dgm
diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
new file mode 100644
index 00000000..fc963d2f
--- /dev/null
+++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
@@ -0,0 +1,77 @@
+import numpy as np
+import tensorflow as tf
+
+#########################################
+# Lower star filtration on simplex tree #
+#########################################
+
+# The parameters of the model are the vertex function values of the simplex tree.
+
+def _LowerStarSimplexTree(simplextree, filtration, dimension):
+ # Parameters: simplextree (simplex tree on which to compute persistence)
+ # filtration (function values on the vertices of st),
+ # dimension (homology dimension),
+
+ for s,_ in simplextree.get_filtration():
+ simplextree.assign_filtration(s, -1e10)
+
+ # Assign new filtration values
+ for i in range(simplextree.num_vertices()):
+ simplextree.assign_filtration([i], filtration[i])
+ simplextree.make_filtration_non_decreasing()
+
+ # Compute persistence diagram
+ dgm = simplextree.persistence()
+
+ # Get vertex pairs for optimization. First, get all simplex pairs
+ pairs = simplextree.persistence_pairs()
+
+ # Then, loop over all simplex pairs
+ indices, pers = [], []
+ for s1, s2 in pairs:
+ # Select pairs with good homological dimension and finite lifetime
+ if len(s1) == dimension+1 and len(s2) > 0:
+ # Get IDs of the vertices corresponding to the filtration values of the simplices
+ l1, l2 = np.array(s1), np.array(s2)
+ i1 = l1[np.argmax(filtration[l1])]
+ i2 = l2[np.argmax(filtration[l2])]
+ indices.append(i1)
+ indices.append(i2)
+ # Compute lifetime
+ pers.append(simplextree.filtration(s2)-simplextree.filtration(s1))
+
+ # Sort vertex pairs wrt lifetime
+ perm = np.argsort(pers)
+ indices = np.reshape(indices, [-1,2])[perm][::-1,:].flatten()
+
+ return np.array(indices, dtype=np.int32)
+
+class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing lower-star persistence out of a simplex tree
+
+ Attributes:
+ simplextree (gudhi.SimplexTree()): underlying simplex tree
+ dimension (int): homology dimension
+ """
+ def __init__(self, simplextree, dimension=0, **kwargs):
+ super().__init__(dynamic=True, **kwargs)
+ self.dimension = dimension
+ self.simplextree = simplextree
+
+ def build(self):
+ super.build()
+
+ def call(self, filtration):
+ """
+ Compute lower-star persistence diagram associated to a function defined on the vertices of the simplex tree
+
+ Parameters:
+ F (TensorFlow variable): filter function values over the vertices of the simplex tree
+ """
+ # Don't try to compute gradients for the vertex pairs
+ indices = tf.stop_gradient(_LowerStarSimplexTree(self.simplextree, filtration.numpy(), self.dimension))
+ # Get persistence diagram
+ self.dgm = tf.reshape(tf.gather(filtration, indices), [-1,2])
+ return self.dgm
+
diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py
new file mode 100644
index 00000000..373e021e
--- /dev/null
+++ b/src/python/gudhi/tensorflow/rips_layer.py
@@ -0,0 +1,75 @@
+import numpy as np
+import tensorflow as tf
+from ..rips_complex import RipsComplex
+
+############################
+# Vietoris-Rips filtration #
+############################
+
+# The parameters of the model are the point coordinates.
+
+def _Rips(DX, max_edge, dimension):
+ # Parameters: DX (distance matrix),
+ # max_edge (maximum edge length for Rips filtration),
+ # dimension (homology dimension)
+
+ # Compute the persistence pairs with Gudhi
+ rc = RipsComplex(distance_matrix=DX, max_edge_length=max_edge)
+ st = rc.create_simplex_tree(max_dimension=dimension+1)
+ dgm = st.persistence()
+ pairs = st.persistence_pairs()
+
+ # Retrieve vertices v_a and v_b by picking the ones achieving the maximal
+ # distance among all pairwise distances between the simplex vertices
+ indices, pers = [], []
+ for s1, s2 in pairs:
+ if len(s1) == dimension+1 and len(s2) > 0:
+ l1, l2 = np.array(s1), np.array(s2)
+ i1 = [l1[v] for v in np.unravel_index(np.argmax(DX[l1,:][:,l1]),[len(l1), len(l1)])]
+ i2 = [l2[v] for v in np.unravel_index(np.argmax(DX[l2,:][:,l2]),[len(l2), len(l2)])]
+ indices.append(i1)
+ indices.append(i2)
+ pers.append(st.filtration(s2)-st.filtration(s1))
+
+ # Sort points with distance-to-diagonal
+ perm = np.argsort(pers)
+ indices = np.reshape(indices, [-1,4])[perm][::-1,:].flatten()
+
+ return np.array(indices, dtype=np.int32)
+
+class RipsLayer(tf.keras.layers.Layer):
+ """
+ TensorFlow layer for computing Rips persistence out of a point cloud
+
+ Attributes:
+ maximum_edge_length (float): maximum edge length for the Rips complex
+ dimension (int): homology dimension
+ """
+ def __init__(self, maximum_edge_length=12, dimension=1, **kwargs):
+ super().__init__(dynamic=True, **kwargs)
+ self.max_edge = maximum_edge_length
+ self.dimension = dimension
+
+ def build(self):
+ super.build()
+
+ def call(self, X):
+ """
+ Compute Rips persistence diagram associated to a point cloud
+
+ Parameters:
+ X (TensorFlow variable): point cloud of shape [number of points, number of dimensions]
+ """
+ # Compute distance matrix
+ DX = tf.math.sqrt(tf.reduce_sum((tf.expand_dims(X, 1)-tf.expand_dims(X, 0))**2, 2))
+ # Compute vertices associated to positive and negative simplices
+ # Don't compute gradient for this operation
+ indices = tf.stop_gradient(_Rips(DX.numpy(), self.max_edge, self.dimension))
+ # Get persistence diagram by simply picking the corresponding entries in the distance matrix
+ if self.dimension > 0:
+ dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(indices, [-1,2])), [-1,2])
+ else:
+ indices = tf.reshape(indices, [-1,2])[1::2,:]
+ dgm = tf.concat([tf.zeros([indices.shape[0],1]), tf.reshape(tf.gather_nd(DX, indices), [-1,1])], axis=1)
+ return dgm
+
diff --git a/src/python/test/test_diff.py b/src/python/test/test_diff.py
new file mode 100644
index 00000000..73a03697
--- /dev/null
+++ b/src/python/test/test_diff.py
@@ -0,0 +1,67 @@
+from gudhi.tensorflow import *
+import numpy as np
+import tensorflow as tf
+import gudhi as gd
+
+def test_rips_diff():
+
+ Xinit = np.array([[1.,1.],[2.,2.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ rl = RipsLayer(maximum_edge_length=2., dimension=0)
+
+ with tf.GradientTape() as tape:
+ dgm = rl.call(X)
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert np.abs(grads[0].numpy()-np.array([[-.5,-.5],[.5,.5]])).sum() <= 1e-6
+
+
+def test_cubical_diff():
+
+ Xinit = np.array([[0.,2.,2.],[2.,2.,2.],[2.,2.,1.]], dtype=np.float32)
+ X = tf.Variable(initial_value=Xinit, trainable=True)
+ cl = CubicalLayer(dimension=0)
+
+ with tf.GradientTape() as tape:
+ dgm = cl.call(X)
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [X])
+ assert np.abs(grads[0].numpy()-np.array([[0.,0.,0.],[0.,.5,0.],[0.,0.,-.5]])).sum() <= 1e-6
+
+def test_st_diff():
+
+ st = gd.SimplexTree()
+ st.insert([0])
+ st.insert([1])
+ st.insert([2])
+ st.insert([3])
+ st.insert([4])
+ st.insert([5])
+ st.insert([6])
+ st.insert([7])
+ st.insert([8])
+ st.insert([9])
+ st.insert([10])
+ st.insert([0, 1])
+ st.insert([1, 2])
+ st.insert([2, 3])
+ st.insert([3, 4])
+ st.insert([4, 5])
+ st.insert([5, 6])
+ st.insert([6, 7])
+ st.insert([7, 8])
+ st.insert([8, 9])
+ st.insert([9, 10])
+
+ Finit = np.array([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=np.float32)
+ F = tf.Variable(initial_value=Finit, trainable=True)
+ sl = LowerStarSimplexTreeLayer(simplextree=st, dimension=0)
+
+ with tf.GradientTape() as tape:
+ dgm = sl.call(F)
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+ grads = tape.gradient(loss, [F])
+
+ assert np.array_equal(np.array(grads[0].indices), np.array([2,4]))
+ assert np.array_equal(np.array(grads[0].values), np.array([-1,1]))
+