summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2022-02-28 08:41:38 +0100
committerMathieuCarriere <mathieu.carriere3@gmail.com>2022-02-28 08:41:38 +0100
commitcbba4bf2005ce129691d358a2d7475c5132e39e0 (patch)
treeaaa2da3600a05521986e6bb5e8512be3c5e3c423
parent3d78c677dfe4aa8515934ab2494fd6faaad0bc67 (diff)
changed doc + added tensorflow indexing
-rw-r--r--src/python/doc/ls_simplex_tree_tflow_itf_ref.rst2
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py2
-rw-r--r--src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py6
-rw-r--r--src/python/gudhi/tensorflow/rips_layer.py4
4 files changed, 7 insertions, 7 deletions
diff --git a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
index 56bb4492..b8518cdb 100644
--- a/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
+++ b/src/python/doc/ls_simplex_tree_tflow_itf_ref.rst
@@ -42,7 +42,7 @@ Example of gradient computed from lower-star filtration of a simplex tree
.. testoutput::
[2 4]
- [-1. 1.]
+ [-1. 1.]
Documentation for LowerStarSimplexTreeLayer
-------------------------------------------
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
index 99d02d66..369b0e54 100644
--- a/src/python/gudhi/tensorflow/cubical_layer.py
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -35,7 +35,7 @@ def _Cubical(Xflat, Xdim, dimensions):
class CubicalLayer(tf.keras.layers.Layer):
"""
- TensorFlow layer for computing cubical persistence out of a cubical complex
+ TensorFlow layer for computing the persistent homology of a cubical complex
"""
def __init__(self, dimensions, min_persistence=None, **kwargs):
"""
diff --git a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
index 8da1f7fe..cf7df6de 100644
--- a/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
+++ b/src/python/gudhi/tensorflow/lower_star_simplex_tree_layer.py
@@ -47,7 +47,7 @@ class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
Constructor for the LowerStarSimplexTreeLayer class
Parameters:
- simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n = number of vertices
+ simplextree (gudhi.SimplexTree): underlying simplex tree. Its vertices MUST be named with integers from 0 to n = number of vertices. Note that its filtration values are modified in each call of the class.
dimensions (List[int]): homology dimensions
min_persistence (List[float]): minimum distance-to-diagonal of the points in the output persistence diagrams (default None, in which case 0. is used for all dimensions)
"""
@@ -76,8 +76,8 @@ class LowerStarSimplexTreeLayer(tf.keras.layers.Layer):
essential_dgm = tf.reshape(tf.gather(filtration, indices[idx_dim][1]), [-1,1])
min_pers = self.min_persistence[idx_dim]
if min_pers >= 0:
- persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel()
- self.dgms.append((tf.gather(finite_dgm, indices=persistent_indices), essential_dgm))
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm))
else:
self.dgms.append((finite_dgm, essential_dgm))
return self.dgms
diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py
index 88d501c1..7b5edfa3 100644
--- a/src/python/gudhi/tensorflow/rips_layer.py
+++ b/src/python/gudhi/tensorflow/rips_layer.py
@@ -83,8 +83,8 @@ class RipsLayer(tf.keras.layers.Layer):
essential_dgm = tf.zeros([cur_idx[1].shape[0],1])
min_pers = self.min_persistence[idx_dim]
if min_pers >= 0:
- persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel()
- self.dgms.append((tf.gather(finite_dgm, indices=persistent_indices), essential_dgm))
+ persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
+ self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm))
else:
self.dgms.append((finite_dgm, essential_dgm))
return self.dgms