:orphan:

.. To get rid of WARNING: document isn't included in any toctree

TensorFlow layer for lower-star persistence on simplex trees
############################################################

.. include:: differentiation_sum.inc

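Example of a gradient computed from the lower-star filtration of a simplex tree
-------------------------------------------------------------------------------

The simplex tree below is a path on the 11 vertices ``0, ..., 10``, and ``F`` gives the filtration value of each vertex. The loss sums the squared half-persistences of the finite points of the 0-dimensional persistence diagram, so its gradient only involves the vertices at which those points are born or die.
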
.. testcode::

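    from gudhi.tensorflow import LowerStarSimplexTreeLayer
    import tensorflow as tf
    import gudhi as gd

    # Build a path graph on the vertices 0, 1, ..., 10
    st = gd.SimplexTree()
    st.insert([0, 1])
    st.insert([1, 2])
    st.insert([2, 3])
    st.insert([3, 4])
    st.insert([4, 5])
    st.insert([5, 6])
    st.insert([6, 7])
    st.insert([7, 8])
    st.insert([8, 9])
    st.insert([9, 10])

    # One filtration value per vertex; each edge gets the max of its endpoints (lower-star filtration)
    F = tf.Variable([6.,4.,3.,4.,5.,4.,3.,2.,3.,4.,5.], dtype=tf.float32, trainable=True)
    # Only compute the persistence diagram in homology dimension 0
    sl = LowerStarSimplexTreeLayer(simplextree=st, homology_dimensions=[0])

    with tf.GradientTape() as tape:
        # [0][0]: finite points of the dimension-0 diagram
        dgm = sl.call(F)[0][0]
        # Sum of squared half-persistences (death - birth) / 2
        loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))

    # The gradient is sparse: only vertices that are the birth or death
    # of some diagram point receive a value
    grads = tape.gradient(loss, [F])
    print(grads[0].indices.numpy())
    print(grads[0].values.numpy())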

.. testoutput::

    [2 4]
    [-1.  1.]
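
To see where the indices ``[2 4]`` and values ``[-1.  1.]`` come from, the sketch below recomputes the example by hand in plain Python, without GUDHI or TensorFlow. It assumes the simplex tree is the path ``0-1-...-10`` built above, so that dimension-0 lower-star persistence reduces to a union-find sweep over the vertices in increasing order of ``F``, with the elder rule deciding which component dies at a merge. The helpers ``lower_star_pairs_path`` and ``loss_and_grad`` are hypothetical illustrations, not part of ``gudhi``.

.. code-block:: python

    # Illustrative hand check (hypothetical helpers, not the gudhi implementation).
    # Assumption: the complex is a path graph 0-1-...-(n-1).

    def lower_star_pairs_path(values):
        """Finite (birth_vertex, death_vertex) pairs of H0 with positive persistence."""
        n = len(values)
        parent = list(range(n))

        def find(i):
            # Union-find root lookup with path halving; roots are the oldest vertex.
            while parent[i] != i:
                parent[i] = parent[parent[i]]
                i = parent[i]
            return i

        added = [False] * n
        pairs = []
        # Sweep vertices by increasing filtration value (ties broken by index).
        for v in sorted(range(n), key=lambda i: (values[i], i)):
            added[v] = True
            # Edges {v, w} with w already present belong to the lower star of v.
            for w in (v - 1, v + 1):
                if 0 <= w < n and added[w]:
                    rv, rw = find(v), find(w)
                    if rv == rw:
                        continue
                    # Elder rule: the component with the younger root dies at v.
                    young, old = sorted((rv, rw), key=lambda r: (values[r], r), reverse=True)
                    if values[v] > values[young]:  # drop zero-persistence pairs
                        pairs.append((young, v))
                    parent[young] = old
        return pairs

    def loss_and_grad(values):
        """Sum of (0.5 * (death - birth))**2 over finite pairs, and its sparse gradient."""
        loss = 0.0
        grad = {}
        for b, d in lower_star_pairs_path(values):
            half = 0.5 * (values[d] - values[b])
            loss += half ** 2
            # d/d(death) of half**2 is 2 * half * 0.5 = half; d/d(birth) is -half.
            grad[d] = grad.get(d, 0.0) + half
            grad[b] = grad.get(b, 0.0) - half
        return loss, grad

    F = [6., 4., 3., 4., 5., 4., 3., 2., 3., 4., 5.]
    loss, grad = loss_and_grad(F)
    print(loss)                  # 1.0
    print(sorted(grad.items()))  # [(2, -1.0), (4, 1.0)]

The only finite pair is born at vertex 2 (value 3) and dies at vertex 4 (value 5), when the two valleys merge, so the loss is ``(0.5 * (5 - 3))**2 = 1`` with derivative ``-1`` for the birth value and ``+1`` for the death value. The essential component born at vertex 7 is ignored, matching our reading that ``[0][0]`` selects only the finite part of the dimension-0 diagram.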

Documentation for LowerStarSimplexTreeLayer
-------------------------------------------

.. autoclass:: gudhi.tensorflow.LowerStarSimplexTreeLayer
   :members:
   :special-members: __init__
   :show-inheritance: