summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/rips_layer.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/tensorflow/rips_layer.py')
-rw-r--r--src/python/gudhi/tensorflow/rips_layer.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py
index a5f212e3..88d501c1 100644
--- a/src/python/gudhi/tensorflow/rips_layer.py
+++ b/src/python/gudhi/tensorflow/rips_layer.py
@@ -40,7 +40,7 @@ class RipsLayer(tf.keras.layers.Layer):
"""
TensorFlow layer for computing Rips persistence out of a point cloud
"""
- def __init__(self, dimensions, maximum_edge_length=12, min_persistence=None, **kwargs):
+ def __init__(self, dimensions, maximum_edge_length=np.inf, min_persistence=None, **kwargs):
"""
Constructor for the RipsLayer class
@@ -66,7 +66,7 @@ class RipsLayer(tf.keras.layers.Layer):
dgms (list of tuple of TensorFlow variables): list of Rips persistence diagrams of length self.dimensions, where each element of the list is a tuple that contains the finite and essential persistence diagrams of shapes [num_finite_points, 2] and [num_essential_points, 1] respectively
"""
# Compute distance matrix
- DX = tf.math.sqrt(tf.reduce_sum((tf.expand_dims(X, 1)-tf.expand_dims(X, 0))**2, 2))
+ DX = tf.norm(tf.expand_dims(X, 1)-tf.expand_dims(X, 0), axis=2)
# Compute vertices associated to positive and negative simplices
# Don't compute gradient for this operation
indices = _Rips(DX.numpy(), self.max_edge, self.dimensions)