summaryrefslogtreecommitdiff
path: root/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
diff options
context:
space:
mode:
authorVincent Rouvreau <vincent.rouvreau@inria.fr>2022-08-05 10:18:19 +0200
committerVincent Rouvreau <vincent.rouvreau@inria.fr>2022-08-05 10:18:19 +0200
commit2b4bf47e209225a56687b2a7fa65b27ef4b00ab2 (patch)
treee876553cc2f8c18e0f504a33b26230d0e3e491ed /src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
parent3253e504a0564bc75ffd4b1351e800593ffefd0f (diff)
parent7fa45f4f0c7fb89abf64bc61b26a6201ace16a7a (diff)
Merge master and fix conflicts
Diffstat (limited to 'src/python/doc/ls_simplex_tree_tflow_itf_ref.rst')
-rw-r--r--src/python/doc/ls_simplex_tree_tflow_itf_ref.rst53
1 files changed, 53 insertions, 0 deletions
diff --git a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
new file mode 100644
index 00000000..9d7d633f
--- /dev/null
+++ b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
@@ -0,0 +1,53 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+TensorFlow layer for lower-star persistence on simplex trees
+############################################################
+
+.. include:: differentiation_sum.inc
+
+Example of gradient computed from lower-star filtration of a simplex tree
+-------------------------------------------------------------------------
+
+.. testcode::
+
+ from gudhi.tensorflow import LowerStarSimplexTreeLayer
+ import tensorflow as tf
+ import gudhi as gd
+
+ st = gd.SimplexTree()
+ st.insert([0, 1])
+ st.insert([1, 2])
+ st.insert([2, 3])
+ st.insert([3, 4])
+ st.insert([4, 5])
+ st.insert([5, 6])
+ st.insert([6, 7])
+ st.insert([7, 8])
+ st.insert([8, 9])
+ st.insert([9, 10])
+
+ F = tf.Variable([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=tf.float32, trainable=True)
+ sl = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0])
+
+ with tf.GradientTape() as tape:
+ dgm = sl.call(F)[0][0]
+ loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
+
+ grads = tape.gradient(loss, [F])
+ print(grads[0].indices.numpy())
+ print(grads[0].values.numpy())
+
+.. testoutput::
+
+ [2 4]
+ [-1. 1.]
+
+Documentation for LowerStarSimplexTreeLayer
+-------------------------------------------
+
+.. autoclass:: gudhi.tensorflow.LowerStarSimplexTreeLayer
+ :members:
+ :special-members: __init__
+ :show-inheritance: