diff options
4 files changed, 7 insertions, 7 deletions
diff --git a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst index 56bb4492..b8518cdb 100644 --- a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst +++ b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst @@ -42,7 +42,7 @@ Example of gradient computed from lower-star filtration of a simplex tree .. testoutput:: [2 4] - [-1. 1.] + [-1. 1.] Documentation for LowerStarSimplexTreeLayer ------------------------------------------- diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py index 99d02d66..369b0e54 100644 --- a/src/python/gudhi/tensorflow/cubical_layer.py +++ b/src/python/gudhi/tensorflow/cubical_layer.py @@ -35,7 +35,7 @@ def _Cubical(Xflat, Xdim, dimensions): class CubicalLayer(tf.keras.layers.Layer): """ - TensorFlow layer for computing cubical persistence out of a cubical complex + TensorFlow layer for computing the persistent homology of a cubical complex """ def __init__(self, dimensions, min_persistence=None, **kwargs): """ diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py index 8da1f7fe..cf7df6de 100644 --- a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py +++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py @@ -47,7 +47,7 @@ class LowerStarSimplexTreeLayer(tf.keras.layers.Layer): Constructor for the LowerStarSimplexTreeLayer class Parameters: - simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n = number of vertices + simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n = number of vertices. Note that its filtration values are modified in each call of the class. dimensions (List[int]): homology dimensions min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions) """ @@ -76,8 +76,8 @@ class LowerStarSimplexTreeLayer(tf.keras.layers.Layer): essential_dgm = tf.reshape(tf.gather(filtration, indices[idx_dim][1]), [-1,1]) min_pers = self.min_persistence[idx_dim] if min_pers >= 0: - persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel() - self.dgms.append((tf.gather(finite_dgm, indices=persistent_indices), essential_dgm)) + persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers) + self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm)) else: self.dgms.append((finite_dgm, essential_dgm)) return self.dgms diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py index 88d501c1..7b5edfa3 100644 --- a/src/python/gudhi/tensorflow/rips_layer.py +++ b/src/python/gudhi/tensorflow/rips_layer.py @@ -83,8 +83,8 @@ class RipsLayer(tf.keras.layers.Layer): essential_dgm = tf.zeros([cur_idx[1].shape[0],1]) min_pers = self.min_persistence[idx_dim] if min_pers >= 0: - persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel() - self.dgms.append((tf.gather(finite_dgm, indices=persistent_indices), essential_dgm)) + persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers) + self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm)) else: self.dgms.append((finite_dgm, essential_dgm)) return self.dgms |