summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/cubical_layer.py
blob: 99d02d66d2b4e5e51c1fe68c50ed67b77c35aa92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy               as np
import tensorflow          as tf
from ..cubical_complex  import CubicalComplex

######################
# Cubical filtration #
######################

# The parameters of the model are the pixel values.

def _Cubical(Xflat, Xdim, dimensions):
    """
    Compute, for each requested homology dimension, the pixel indices of the
    finite (regular) persistence pairs of an image.

    Parameters:
        Xflat (numpy array): flattened image (pixel values)
        Xdim (sequence of int): shape of the non-flattened image
        dimensions (List[int]): homology dimensions

    Returns:
        List of int32 numpy arrays, one per entry of `dimensions`. Each array holds
        the indices into `Xflat` reported by
        CubicalComplex.cofaces_of_persistence_pairs() for that dimension. It is empty
        when gudhi reports nothing for that dimension (including negative or
        out-of-range dimensions).
    """
    # Compute the persistence pairs with Gudhi.
    # We reverse the dimensions because CubicalComplex uses Fortran ordering.
    cc = CubicalComplex(dimensions=Xdim[::-1], top_dimensional_cells=Xflat)
    cc.compute_persistence()

    # Retrieve and output image indices/pixels corresponding to positive and negative
    # simplices. Per the gudhi documentation, the first element of the returned pair
    # holds the regular (finite) pairs grouped by homology dimension.
    regular_pairs = cc.cofaces_of_persistence_pairs()[0]

    L_cofs = []
    for dim in dimensions:
        # Explicit bounds check rather than catching IndexError: a negative `dim`
        # would not raise but silently wrap around (Python negative indexing) and
        # return the pairs of another dimension.
        if 0 <= dim < len(regular_pairs):
            cof = regular_pairs[dim]
        else:
            cof = np.array([])
        L_cofs.append(np.array(cof, dtype=np.int32))

    return L_cofs

class CubicalLayer(tf.keras.layers.Layer):
    """
    TensorFlow layer for computing cubical persistence out of a cubical complex
    """
    def __init__(self, dimensions, min_persistence=None, **kwargs):
        """
        Constructor for the CubicalLayer class

        Parameters:
            dimensions (List[int]): homology dimensions
            min_persistence (List[float]): one threshold per entry of `dimensions`.
                Points whose persistence |death - birth| is not strictly greater than
                the threshold are removed from the corresponding output diagram. A
                negative threshold disables filtering for that dimension. (default
                None, in which case 0. is used for all dimensions, i.e. only
                zero-persistence points are removed)
        """
        # call() converts its input with .numpy(), which requires eager tensors,
        # hence the layer is declared dynamic.
        super().__init__(dynamic=True, **kwargs)
        self.dimensions = dimensions
        self.min_persistence = min_persistence if min_persistence is not None else [0.] * len(self.dimensions)
        assert len(self.min_persistence) == len(self.dimensions)

    def call(self, X):
        """
        Compute persistence diagram associated to a cubical complex filtered by some pixel values

        Parameters:
            X (TensorFlow variable): pixel values of the cubical complex

        Returns:
            dgms (list of TensorFlow variables): list of cubical persistence diagrams of length self.dimensions, where each element contains a finite persistence diagram of shape [num_finite_points, 2]
        """
        # Compute pixels associated to positive and negative simplices.
        # This runs in NumPy/Gudhi, so no gradient is computed for this operation.
        Xflat = tf.reshape(X, [-1])
        Xdim = X.shape
        indices_list = _Cubical(Xflat.numpy(), Xdim, self.dimensions)
        # Get persistence diagram by simply picking the corresponding entries in the
        # image; gradients flow back to X through tf.gather.
        self.dgms = [tf.reshape(tf.gather(Xflat, indices), [-1, 2]) for indices in indices_list]
        for idx_dim, min_pers in enumerate(self.min_persistence):
            if min_pers >= 0:
                finite_dgm = self.dgms[idx_dim]
                persistence = tf.math.abs(finite_dgm[:, 1] - finite_dgm[:, 0])
                # tf.where on a 1-D condition returns indices of shape [k, 1]. Flatten
                # them so the gather keeps the [k, 2] diagram shape; gathering with
                # [k, 1] indices would produce [k, 1, 2].
                persistent_indices = tf.reshape(tf.where(persistence > min_pers), [-1])
                self.dgms[idx_dim] = tf.gather(finite_dgm, indices=persistent_indices)
        return self.dgms