summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
blob: fc963d2f5b57788f5bde63dcf828af3a05d5b1fa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy               as np
import tensorflow          as tf

#########################################
# Lower star filtration on simplex tree #
#########################################

# The parameters of the model are the vertex function values of the simplex tree.

def _LowerStarSimplexTree(simplextree, filtration, dimension):
    # Parameters: simplextree (simplex tree on which to compute persistence)
    #             filtration (function values on the vertices of st),
    #             dimension (homology dimension),
    
    for s,_ in simplextree.get_filtration():
        simplextree.assign_filtration(s, -1e10)

    # Assign new filtration values
    for i in range(simplextree.num_vertices()):
        simplextree.assign_filtration([i], filtration[i])
    simplextree.make_filtration_non_decreasing()
    
    # Compute persistence diagram
    dgm = simplextree.persistence()
    
    # Get vertex pairs for optimization. First, get all simplex pairs
    pairs = simplextree.persistence_pairs()
    
    # Then, loop over all simplex pairs
    indices, pers = [], []
    for s1, s2 in pairs:
        # Select pairs with good homological dimension and finite lifetime
        if len(s1) == dimension+1 and len(s2) > 0:
            # Get IDs of the vertices corresponding to the filtration values of the simplices
            l1, l2 = np.array(s1), np.array(s2)
            i1 = l1[np.argmax(filtration[l1])]
            i2 = l2[np.argmax(filtration[l2])]
            indices.append(i1)
            indices.append(i2)
            # Compute lifetime
            pers.append(simplextree.filtration(s2)-simplextree.filtration(s1))
    
    # Sort vertex pairs wrt lifetime
    perm = np.argsort(pers)
    indices = np.reshape(indices, [-1,2])[perm][::-1,:].flatten()
    
    return np.array(indices, dtype=np.int32)

class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
    """
    TensorFlow layer for computing lower-star persistence out of a simplex tree

    Attributes:
        simplextree (gudhi.SimplexTree()): underlying simplex tree 
        dimension (int): homology dimension
    """
    def __init__(self, simplextree, dimension=0, **kwargs):
        super().__init__(dynamic=True, **kwargs)
        self.dimension   = dimension
        self.simplextree = simplextree
    
    def build(self):
        super.build()
    
    def call(self, filtration):
        """
        Compute lower-star persistence diagram associated to a function defined on the vertices of the simplex tree

        Parameters:
            F (TensorFlow variable): filter function values over the vertices of the simplex tree
        """
        # Don't try to compute gradients for the vertex pairs
        indices = tf.stop_gradient(_LowerStarSimplexTree(self.simplextree, filtration.numpy(), self.dimension))
        # Get persistence diagram
        self.dgm = tf.reshape(tf.gather(filtration, indices), [-1,2]) 
        return self.dgm