summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/cubical_layer.py
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2022-02-02 22:11:04 +0100
committerMathieuCarriere <mathieu.carriere3@gmail.com>2022-02-02 22:11:04 +0100
commit5c00d2dfcf4b0e2835441533f12f195d83652e99 (patch)
tree1b938487aee9c58cd5fbc2aadbe35896b5e47005 /src/python/gudhi/tensorflow/cubical_layer.py
parentc07e645abc27350351af73fa9b24b3d5f881033e (diff)
fixed bugs from the new API
Diffstat (limited to 'src/python/gudhi/tensorflow/cubical_layer.py')
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
index b16c512f..99d02d66 100644
--- a/src/python/gudhi/tensorflow/cubical_layer.py
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -47,7 +47,7 @@ class CubicalLayer(tf.keras.layers.Layer):
"""
super().__init__(dynamic=True, **kwargs)
self.dimensions = dimensions
- self.min_persistence = min_persistence if min_persistence != None else [0. for _ in range(len(self.dimensions))]
+ self.min_persistence = min_persistence if min_persistence != None else [0.] * len(self.dimensions)
assert len(self.min_persistence) == len(self.dimensions)
def call(self, X):
@@ -64,13 +64,13 @@ class CubicalLayer(tf.keras.layers.Layer):
# Don't compute gradient for this operation
Xflat = tf.reshape(X, [-1])
Xdim = X.shape
- indices = _Cubical(Xflat.numpy(), Xdim, self.dimensions)
+ indices_list = _Cubical(Xflat.numpy(), Xdim, self.dimensions)
# Get persistence diagram by simply picking the corresponding entries in the image
- self.dgms = [tf.reshape(tf.gather(Xflat, indice), [-1,2]) for indice in indices]
+ self.dgms = [tf.reshape(tf.gather(Xflat, indices), [-1,2]) for indices in indices_list]
for idx_dim in range(len(self.min_persistence)):
min_pers = self.min_persistence[idx_dim]
if min_pers >= 0:
finite_dgm = self.dgms[idx_dim]
- persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel()
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
self.dgms[idx_dim] = tf.gather(finite_dgm, indices=persistent_indices)
return self.dgms