summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py')
-rw-r--r--src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
index 8da1f7fe..cf7df6de 100644
--- a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
+++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
@@ -47,7 +47,7 @@ class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
Constructor for the LowerStarSimplexTreeLayer class
Parameters:
- simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n = number of vertices
+ simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n = number of vertices. Note that its filtration values are modified in each call of the class.
dimensions (List[int]): homology dimensions
min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions)
"""
@@ -76,8 +76,8 @@ class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
essential_dgm = tf.reshape(tf.gather(filtration, indices[idx_dim][1]), [-1,1])
min_pers = self.min_persistence[idx_dim]
if min_pers >= 0:
- persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel()
- self.dgms.append((tf.gather(finite_dgm, indices=persistent_indices), essential_dgm))
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm))
else:
self.dgms.append((finite_dgm, essential_dgm))
return self.dgms