From 5c00d2dfcf4b0e2835441533f12f195d83652e99 Mon Sep 17 00:00:00 2001 From: MathieuCarriere Date: Wed, 2 Feb 2022 22:11:04 +0100 Subject: fixed bugs from the new API --- src/python/gudhi/tensorflow/cubical_layer.py | 8 ++++---- src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py | 3 +-- src/python/gudhi/tensorflow/rips_layer.py | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) (limited to 'src/python/gudhi') diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py index b16c512f..99d02d66 100644 --- a/src/python/gudhi/tensorflow/cubical_layer.py +++ b/src/python/gudhi/tensorflow/cubical_layer.py @@ -47,7 +47,7 @@ class CubicalLayer(tf.keras.layers.Layer): """ super().__init__(dynamic=True, **kwargs) self.dimensions = dimensions - self.min_persistence = min_persistence if min_persistence != None else [0. for _ in range(len(self.dimensions))] + self.min_persistence = min_persistence if min_persistence != None else [0.] * len(self.dimensions) assert len(self.min_persistence) == len(self.dimensions) def call(self, X): @@ -64,13 +64,13 @@ class CubicalLayer(tf.keras.layers.Layer): # Don't compute gradient for this operation Xflat = tf.reshape(X, [-1]) Xdim = X.shape - indices = _Cubical(Xflat.numpy(), Xdim, self.dimensions) + indices_list = _Cubical(Xflat.numpy(), Xdim, self.dimensions) # Get persistence diagram by simply picking the corresponding entries in the image - self.dgms = [tf.reshape(tf.gather(Xflat, indice), [-1,2]) for indice in indices] + self.dgms = [tf.reshape(tf.gather(Xflat, indices), [-1,2]) for indices in indices_list] for idx_dim in range(len(self.min_persistence)): min_pers = self.min_persistence[idx_dim] if min_pers >= 0: finite_dgm = self.dgms[idx_dim] - persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel() + persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers) self.dgms[idx_dim] = tf.gather(finite_dgm, indices=persistent_indices) return self.dgms diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py index e1627944..8da1f7fe 100644 --- a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py +++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py @@ -12,8 +12,7 @@ def _LowerStarSimplexTree(simplextree, filtration, dimensions): # filtration (function values on the vertices of st), # dimensions (homology dimensions), - for s,_ in simplextree.get_filtration(): - simplextree.assign_filtration(s, -1e10) + simplextree.reset_filtration(-np.inf, 0) # Assign new filtration values for i in range(simplextree.num_vertices()): diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py index a5f212e3..88d501c1 100644 --- a/src/python/gudhi/tensorflow/rips_layer.py +++ b/src/python/gudhi/tensorflow/rips_layer.py @@ -40,7 +40,7 @@ class RipsLayer(tf.keras.layers.Layer): """ TensorFlow layer for computing Rips persistence out of a point cloud """ - def __init__(self, dimensions, maximum_edge_length=12, min_persistence=None, **kwargs): + def __init__(self, dimensions, maximum_edge_length=np.inf, min_persistence=None, **kwargs): """ Constructor for the RipsLayer class @@ -66,7 +66,7 @@ class RipsLayer(tf.keras.layers.Layer): dgms (list of tuple of TensorFlow variables): list of Rips persistence diagrams of length self.dimensions, where each element of the list is a tuple that contains the finite and essential persistence diagrams of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively """ # Compute distance matrix - DX = tf.math.sqrt(tf.reduce_sum((tf.expand_dims(X, 1)-tf.expand_dims(X, 0))**2, 2)) + DX = tf.norm(tf.expand_dims(X, 1)-tf.expand_dims(X, 0), axis=2) # Compute vertices associated to positive and negative simplices # Don't compute gradient for this operation indices = _Rips(DX.numpy(), self.max_edge, self.dimensions) -- cgit v1.2.3