summaryrefslogtreecommitdiff
path: root/src/python/gudhi
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2022-02-02 22:11:04 +0100
committerMathieuCarriere <mathieu.carriere3@gmail.com>2022-02-02 22:11:04 +0100
commit5c00d2dfcf4b0e2835441533f12f195d83652e99 (patch)
tree1b938487aee9c58cd5fbc2aadbe35896b5e47005 /src/python/gudhi
parentc07e645abc27350351af73fa9b24b3d5f881033e (diff)
fixed bugs from the new API
Diffstat (limited to 'src/python/gudhi')
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py8
-rw-r--r--src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py3
-rw-r--r--src/python/gudhi/tensorflow/rips_layer.py4
3 files changed, 7 insertions, 8 deletions
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
index b16c512f..99d02d66 100644
--- a/src/python/gudhi/tensorflow/cubical_layer.py
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -47,7 +47,7 @@ class CubicalLayer(tf.keras.layers.Layer):
"""
super().__init__(dynamic=True, **kwargs)
self.dimensions = dimensions
- self.min_persistence = min_persistence if min_persistence != None else [0. for _ in range(len(self.dimensions))]
+ self.min_persistence = min_persistence if min_persistence != None else [0.] * len(self.dimensions)
assert len(self.min_persistence) == len(self.dimensions)
def call(self, X):
@@ -64,13 +64,13 @@ class CubicalLayer(tf.keras.layers.Layer):
# Don't compute gradient for this operation
Xflat = tf.reshape(X, [-1])
Xdim = X.shape
- indices = _Cubical(Xflat.numpy(), Xdim, self.dimensions)
+ indices_list = _Cubical(Xflat.numpy(), Xdim, self.dimensions)
# Get persistence diagram by simply picking the corresponding entries in the image
- self.dgms = [tf.reshape(tf.gather(Xflat, indice), [-1,2]) for indice in indices]
+ self.dgms = [tf.reshape(tf.gather(Xflat, indices), [-1,2]) for indices in indices_list]
for idx_dim in range(len(self.min_persistence)):
min_pers = self.min_persistence[idx_dim]
if min_pers >= 0:
finite_dgm = self.dgms[idx_dim]
- persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel()
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
self.dgms[idx_dim] = tf.gather(finite_dgm, indices=persistent_indices)
return self.dgms
diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
index e1627944..8da1f7fe 100644
--- a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
+++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
@@ -12,8 +12,7 @@ def _LowerStarSimplexTree(simplextree, filtration, dimensions):
# filtration (function values on the vertices of st),
# dimensions (homology dimensions),
- for s,_ in simplextree.get_filtration():
- simplextree.assign_filtration(s, -1e10)
+ simplextree.reset_filtration(-np.inf, 0)
# Assign new filtration values
for i in range(simplextree.num_vertices()):
diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py
index a5f212e3..88d501c1 100644
--- a/src/python/gudhi/tensorflow/rips_layer.py
+++ b/src/python/gudhi/tensorflow/rips_layer.py
@@ -40,7 +40,7 @@ class RipsLayer(tf.keras.layers.Layer):
"""
TensorFlow layer for computing Rips persistence out of a point cloud
"""
- def __init__(self, dimensions, maximum_edge_length=12, min_persistence=None, **kwargs):
+ def __init__(self, dimensions, maximum_edge_length=np.inf, min_persistence=None, **kwargs):
"""
Constructor for the RipsLayer class
@@ -66,7 +66,7 @@ class RipsLayer(tf.keras.layers.Layer):
dgms (list of tuple of TensorFlow variables): list of Rips persistence diagrams of length self.dimensions, where each element of the list is a tuple that contains the finite and essential persistence diagrams of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively
"""
# Compute distance matrix
- DX = tf.math.sqrt(tf.reduce_sum((tf.expand_dims(X, 1)-tf.expand_dims(X, 0))**2, 2))
+ DX = tf.norm(tf.expand_dims(X, 1)-tf.expand_dims(X, 0), axis=2)
# Compute vertices associated to positive and negative simplices
# Don't compute gradient for this operation
indices = _Rips(DX.numpy(), self.max_edge, self.dimensions)