summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/perslay.py
blob: 9976c5f37c803ccc7eb6d4427c28e4aecf8add9e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s):       Mathieu Carrière
#
# Copyright (C) 2021 Inria
#
# Modification(s):
#   - YYYY/MM Author: Description of the modification

import tensorflow as tf
import math

class GridPerslayWeight(tf.keras.layers.Layer):
    """
    This is a class for computing a differentiable weight function for persistence diagram points. This function is defined from an array that contains its values on a 2D grid.
    """
    def __init__(self, grid, grid_bnds, **kwargs):
        """
        Constructor for the GridPerslayWeight class.
  
        Parameters:
            grid (n_x x n_y numpy array): grid of values (trainable). Entry [i, j] is the weight of the cell made of the i-th of n_x equal intervals of [min_x, max_x] and the j-th of n_y equal intervals of [min_y, max_y].
            grid_bnds (2 x 2 numpy array): boundaries of the grid, of the form [[min_x, max_x], [min_y, max_y]].
        """
        super().__init__(dynamic=True, **kwargs)
        self.grid = tf.Variable(initial_value=grid, trainable=True)
        self.grid_bnds = grid_bnds
    
    def build(self, input_shape):
        return self

    def call(self, diagrams):
        """
        Apply GridPerslayWeight on a ragged tensor containing a list of persistence diagrams.

        Each point is given the grid value of the cell that contains it. Points lying outside of the grid boundaries (including points lying exactly on max_x or max_y) are given the value of the nearest border cell.

        Parameters:
            diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.

        Returns:
            weight (n x None): ragged tensor containing the weights of the points in the n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
        """
        grid_shape = self.grid.shape
        indices = []
        for dim in range(2):
            [m, M] = self.grid_bnds[dim]
            coords = tf.expand_dims(diagrams[:, :, dim], -1)
            # Index of the cell containing each coordinate (the cast truncates towards zero).
            ids = tf.cast(grid_shape[dim] * (coords - m) / (M - m), tf.int32)
            # Clamp to [0, n-1]: a coordinate equal to the upper bound would otherwise give
            # index n, and coordinates below the lower bound give negative indices; tf.gather_nd
            # fails on CPU (and silently returns 0 on GPU) for out-of-range indices.
            ids = tf.math.minimum(tf.math.maximum(ids, 0), grid_shape[dim] - 1)
            indices.append(ids)
        weight = tf.gather_nd(params=self.grid, indices=tf.concat(indices, axis=2))
        return weight
    
class GaussianMixturePerslayWeight(tf.keras.layers.Layer):
    """
    This is a class for computing a differentiable weight function for persistence diagram points. This function is defined from a mixture of Gaussian functions.
    """
    def __init__(self, gaussians, **kwargs):
        """
        Constructor for the GaussianMixturePerslayWeight class.
  
        Parameters:
            gaussians (4 x n numpy array): parameters of the n Gaussian functions (trainable), of the form transpose([[mu_x^1, mu_y^1, sigma_x^1, sigma_y^1], ..., [mu_x^n, mu_y^n, sigma_x^n, sigma_y^n]]). 
        """
        super().__init__(dynamic=True, **kwargs)
        self.W = tf.Variable(initial_value=gaussians, trainable=True)

    def build(self, input_shape):
        return self
        
    def call(self, diagrams):
        """
        Apply GaussianMixturePerslayWeight on a ragged tensor containing a list of persistence diagrams.

        The weight of a point p = (p_x, p_y) is
            sum_i exp(-((p_x - mu_x^i)^2 / (sigma_x^i)^2 + (p_y - mu_y^i)^2 / (sigma_y^i)^2)),
        i.e. unnormalized Gaussians without the usual factor 1/2 in the exponent.

        Parameters:
            diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.

        Returns:
            weight (n x None): ragged tensor containing the weights of the points in the n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
        """
        # Parameters reshaped to (1, 1, 2, n_gaussians) so they broadcast against every point.
        means     = tf.expand_dims(tf.expand_dims(self.W[:2,:],0),0)
        variances = tf.expand_dims(tf.expand_dims(self.W[2:,:],0),0)
        # Points reshaped to (n, None, 2, 1): one copy per Gaussian after broadcasting.
        diagrams  = tf.expand_dims(diagrams, -1)
        dists     = tf.math.multiply(tf.math.square(diagrams-means), 1/tf.math.square(variances))
        # Sum over the two coordinates (axis 2), exponentiate, then sum over the Gaussians (now axis 2).
        weight    = tf.math.reduce_sum(tf.math.exp(tf.math.reduce_sum(-dists, axis=2)), axis=2)
        return weight
    
class PowerPerslayWeight(tf.keras.layers.Layer):
    """
    This is a class for computing a differentiable weight function for persistence diagram points. Each point (b, d) receives the weight constant * |d - b|^power, where |d - b| is the persistence of the point (proportional to its distance to the diagonal).
    """
    def __init__(self, constant, power, **kwargs):
        """
        Constructor for the PowerPerslayWeight class.
  
        Parameters:
            constant (float): multiplicative constant (trainable).
            power (float): exponent applied to the persistence |d - b| of each point (not trainable).
        """
        super().__init__(dynamic=True, **kwargs)
        self.power = power
        self.constant = tf.Variable(initial_value=constant, trainable=True)

    def build(self, input_shape):
        return self

    def call(self, diagrams):
        """
        Apply PowerPerslayWeight on a ragged tensor containing a list of persistence diagrams.

        Parameters:
            diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.

        Returns:
            weight (n x None): ragged tensor containing the weights of the points in the n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.
        """
        births = diagrams[:, :, 0]
        deaths = diagrams[:, :, 1]
        persistence = tf.math.abs(deaths - births)
        return self.constant * tf.math.pow(persistence, self.power)
    

class GaussianPerslayPhi(tf.keras.layers.Layer):
    """
    This is a class for computing a transformation function for persistence diagram points. This function turns persistence diagram points into 2D Gaussian functions centered on the points, that are then evaluated on a regular 2D grid.
    """
    def __init__(self, image_size, image_bnds, variance, **kwargs):
        """
        Constructor for the GaussianPerslayPhi class.
  
        Parameters:
            image_size (int numpy array): number of grid elements on each grid axis, of the form [n_x, n_y].
            image_bnds (2 x 2 numpy array): boundaries of the grid, of the form [[min_x, max_x], [min_y, max_y]]. The grid lives in (birth, death - birth) coordinates.
            variance (float): bandwidth of the Gaussian functions (trainable). It is used as the standard deviation sigma in exp(-||p - x||^2 / (2 sigma^2)) / (2 pi sigma^2).
        """
        super().__init__(dynamic=True, **kwargs)
        self.image_size = image_size
        self.image_bnds = image_bnds
        self.variance   = tf.Variable(initial_value=variance, trainable=True)
        
    def build(self, input_shape):
        return self
        
    def call(self, diagrams):
        """
        Apply GaussianPerslayPhi on a ragged tensor containing a list of persistence diagrams.

        Points (b, d) are first mapped to (b, d - b). Grid coordinates along axis i are min_i + k * (max_i - min_i) / n_i for k = 0, ..., n_i - 1.

        Parameters:
            diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.

        Returns:
            output (n x None x n_y x n_x x 1): ragged tensor containing the evaluations on the 2D grid of the 2D Gaussian functions corresponding to the persistence diagram points, in the form of a 2D image with 1 channel that can be processed with, e.g., convolutional layers. The second dimension is ragged since persistence diagrams can have different numbers of points.
            output_shape (TensorShape): shape of one point's image, i.e. (n_y, n_x, 1).
        """
        # Work in (birth, persistence) coordinates.
        diagrams_d = tf.concat([diagrams[:,:,0:1], diagrams[:,:,1:2]-diagrams[:,:,0:1]], axis=2)
        coords = []
        for i in range(2):
            lo, hi = float(self.image_bnds[i][0]), float(self.image_bnds[i][1])
            num = int(self.image_size[i])
            step = (hi - lo) / num
            # Build exactly `num` coordinates from an integer range: a floating-point
            # tf.range(lo, hi, step) may produce num + 1 values because of rounding. Casting to
            # the diagrams' dtype keeps the subtraction below dtype-consistent.
            coords.append(tf.cast(tf.range(num), diagrams.dtype) * step + lo)
        # Default 'xy' indexing: each meshgrid output has shape (n_y, n_x).
        M = tf.meshgrid(*coords)
        mu = tf.concat([tf.expand_dims(tens, 0) for tens in M], axis=0)
        # (n, None, 2) -> (n, None, 2, 1, 1) so every point broadcasts over the (2, n_y, n_x) grid.
        for _ in range(2):
            diagrams_d = tf.expand_dims(diagrams_d,-1)
        dists = tf.math.square(diagrams_d-mu) / (2*tf.math.square(self.variance))
        gauss = tf.math.exp(tf.math.reduce_sum(-dists, axis=2)) / (2*math.pi*tf.math.square(self.variance))
        output = tf.expand_dims(gauss,-1)
        output_shape = M[0].shape + tuple([1])
        return output, output_shape
     
class TentPerslayPhi(tf.keras.layers.Layer):
    """
    This is a class for computing a transformation function for persistence diagram points. This function turns persistence diagram points into 1D tent functions (linearly increasing on the first half of the bar corresponding to the point from zero to half of the bar length, linearly decreasing on the second half and zero elsewhere) centered on the points, that are then evaluated on a regular 1D grid.
    """
    def __init__(self, samples, **kwargs):
        """
        Constructor for the TentPerslayPhi class.
  
        Parameters:
            samples (float numpy array): grid elements on which to evaluate the tent functions (trainable), of the form [x_1, ..., x_n].
        """
        super().__init__(dynamic=True, **kwargs)
        self.samples   = tf.Variable(initial_value=samples, trainable=True)
        
    def build(self, input_shape):
        return self
        
    def call(self, diagrams):
        """
        Apply TentPerslayPhi on a ragged tensor containing a list of persistence diagrams.

        For a point (b, d) and a sample x, the value is max(0, (d - b)/2 - |x - (b + d)/2|).

        Parameters:
            diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.

        Returns:
            output (n x None x num_samples): ragged tensor containing the evaluations on the 1D grid of the 1D tent functions corresponding to the persistence diagram points. The second dimension is ragged since persistence diagrams can have different numbers of points.
            output_shape (TensorShape): shape of one point's vectorization, i.e. (num_samples,).
        """
        # Samples reshaped to (1, 1, num_samples) to broadcast against (n, None, 1) coordinates.
        samples_d = tf.expand_dims(tf.expand_dims(self.samples,0),0)
        xs, ys = diagrams[:,:,0:1], diagrams[:,:,1:2]
        # A scalar zero adopts the input dtype (a float32 constant would break float64 diagrams).
        output = tf.math.maximum(.5*(ys-xs) - tf.math.abs(samples_d-.5*(ys+xs)), 0.)
        output_shape = self.samples.shape
        return output, output_shape
    
class FlatPerslayPhi(tf.keras.layers.Layer):
    """
    This is a class for computing a transformation function for persistence diagram points. This function turns persistence diagram points into smoothed 1D indicator functions of the bars (close to one on the bar corresponding to the point and close to zero elsewhere), that are then evaluated on a regular 1D grid.
    """
    def __init__(self, samples, theta, **kwargs):
        """
        Constructor for the FlatPerslayPhi class.
  
        Parameters:
            samples (float numpy array): grid elements on which to evaluate the constant functions (trainable), of the form [x_1, ..., x_n].
            theta (float): sigmoid parameter (trainable) used to approximate the indicator function of the bar with a differentiable sigmoid function. The bigger the theta, the closer to an indicator function the output will be. 
        """
        super().__init__(dynamic=True, **kwargs)
        self.samples = tf.Variable(initial_value=samples, trainable=True)
        self.theta   = tf.Variable(initial_value=theta,   trainable=True)
        
    def build(self, input_shape):
        return self
        
    def call(self, diagrams):
        """
        Apply FlatPerslayPhi on a ragged tensor containing a list of persistence diagrams.

        For a point (b, d) and a sample x, the value is sigmoid(theta * ((d - b)/2 - |x - (b + d)/2|)), which is above 1/2 when x lies inside (b, d) and below 1/2 outside of it.

        Parameters:
            diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.

        Returns:
            output (n x None x num_samples): ragged tensor containing the evaluations on the 1D grid of the 1D constant functions corresponding to the persistence diagram points. The second dimension is ragged since persistence diagrams can have different numbers of points.
            output_shape (TensorShape): shape of one point's vectorization, i.e. (num_samples,).
        """
        # Samples reshaped to (1, 1, num_samples) to broadcast against (n, None, 1) coordinates.
        samples_d = tf.expand_dims(tf.expand_dims(self.samples,0),0)
        xs, ys = diagrams[:,:,0:1], diagrams[:,:,1:2]
        # tf.math.sigmoid is numerically stable: the hand-written 1/(1+exp(-z)) overflows for
        # large negative z and then yields NaN gradients (0 * inf) for theta and the samples.
        output = tf.math.sigmoid(self.theta*(.5*(ys-xs)-tf.math.abs(samples_d-.5*(ys+xs))))
        output_shape = self.samples.shape
        return output, output_shape

class Perslay(tf.keras.layers.Layer):
    """
    This is a TensorFlow layer for vectorizing persistence diagrams in a differentiable way within a neural network. This function implements the PersLay equation, see `the corresponding article <http://proceedings.mlr.press/v108/carriere20a.html>`_.
    """
    def __init__(self, weight, phi, perm_op, rho, **kwargs):
        """
        Constructor for the Perslay class.

        Parameters:
            weight (function): weight function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.perslay.GridPerslayWeight`, :class:`~gudhi.tensorflow.perslay.GaussianMixturePerslayWeight`, :class:`~gudhi.tensorflow.perslay.PowerPerslayWeight`, or a custom TensorFlow function that takes persistence diagrams as argument (represented as an (n x None x 2) ragged tensor, where n is the number of diagrams).
            phi (function): transformation function for the persistence diagram points. Can be either :class:`~gudhi.tensorflow.perslay.GaussianPerslayPhi`, :class:`~gudhi.tensorflow.perslay.TentPerslayPhi`, :class:`~gudhi.tensorflow.perslay.FlatPerslayPhi`, or a custom TensorFlow class (that can have trainable parameters) with a method `call` that takes persistence diagrams as argument (represented as an (n x None x 2) ragged tensor, where n is the number of diagrams) and returns the transformed points together with the shape of one point's transformation.
            perm_op (function): permutation invariant function, such as `tf.math.reduce_sum`, `tf.math.reduce_mean`, `tf.math.reduce_max`, `tf.math.reduce_min`, or a custom TensorFlow function that takes two arguments: a tensor and an axis on which to apply the permutation invariant operation. If perm_op is the string "topk" (where k is a number, e.g. "top5"), this function will be computed as `tf.math.top_k` with parameter `int(k)`.
            rho (function): postprocessing function that is applied after the permutation invariant operation. Can be any TensorFlow layer.
        """
        super().__init__(dynamic=True, **kwargs)
        self.weight  = weight
        self.phi     = phi
        self.perm_op = perm_op  
        self.rho     = rho

    def build(self, input_shape):
        return self

    def call(self, diagrams):
        """
        Apply Perslay on a ragged tensor containing a list of persistence diagrams.

        With perm_op = "topk", the k largest values over the points of each diagram are kept for every output coordinate of phi. Diagrams with fewer than k points get -1e10 entries, and the largest diagram of the batch must contain at least k points (required by `tf.math.top_k`).

        Parameters:
            diagrams (n x None x 2): ragged tensor containing n persistence diagrams. The second dimension is ragged since persistence diagrams can have different numbers of points.

        Returns:
            vector (n x output_shape): tensor containing the vectorizations of the persistence diagrams.
        """
        vector, dim = self.phi(diagrams)
        weight = self.weight(diagrams)
        # One trailing singleton axis per phi output axis so that weights broadcast over each point's vectorization.
        for _ in range(len(dim)):
            weight = tf.expand_dims(weight, -1)
        vector = tf.math.multiply(vector, weight)

        permop = self.perm_op
        if isinstance(permop, str) and permop.startswith('top'):
            k = int(permop[3:])
            # Densify with a very negative padding value so padded entries rank below real points.
            vector = vector.to_tensor(default_value=-1e10)
            # Move the point axis (axis 1) last, keeping the batch axis first and phi's axes in
            # order; for a 1D phi this is perm=[0, 2, 1].
            perm = [0] + list(range(2, len(dim) + 2)) + [1]
            vector = tf.math.top_k(tf.transpose(vector, perm=perm), k=k).values
            # Flatten everything but the batch axis: k values per phi output coordinate.
            flat_size = k
            for d in dim:
                flat_size *= int(d)
            vector = tf.reshape(vector, [-1, flat_size])
        else:
            vector = permop(vector, axis=1)

        vector = self.rho(vector)

        return vector