diff options
Diffstat (limited to 'src/python/gudhi/tensorflow/cubical_layer.py')
-rw-r--r-- | src/python/gudhi/tensorflow/cubical_layer.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py index b16c512f..99d02d66 100644 --- a/src/python/gudhi/tensorflow/cubical_layer.py +++ b/src/python/gudhi/tensorflow/cubical_layer.py @@ -47,7 +47,7 @@ class CubicalLayer(tf.keras.layers.Layer): """ super().__init__(dynamic=True, **kwargs) self.dimensions = dimensions - self.min_persistence = min_persistence if min_persistence != None else [0. for _ in range(len(self.dimensions))] + self.min_persistence = min_persistence if min_persistence != None else [0.] * len(self.dimensions) assert len(self.min_persistence) == len(self.dimensions) def call(self, X): @@ -64,13 +64,13 @@ class CubicalLayer(tf.keras.layers.Layer): # Don't compute gradient for this operation Xflat = tf.reshape(X, [-1]) Xdim = X.shape - indices = _Cubical(Xflat.numpy(), Xdim, self.dimensions) + indices_list = _Cubical(Xflat.numpy(), Xdim, self.dimensions) # Get persistence diagram by simply picking the corresponding entries in the image - self.dgms = [tf.reshape(tf.gather(Xflat, indice), [-1,2]) for indice in indices] + self.dgms = [tf.reshape(tf.gather(Xflat, indices), [-1,2]) for indices in indices_list] for idx_dim in range(len(self.min_persistence)): min_pers = self.min_persistence[idx_dim] if min_pers >= 0: finite_dgm = self.dgms[idx_dim] - persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel() + persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers) self.dgms[idx_dim] = tf.gather(finite_dgm, indices=persistent_indices) return self.dgms |