summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/cubical_layer.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/tensorflow/cubical_layer.py')
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
index 369b0e54..31c44205 100644
--- a/src/python/gudhi/tensorflow/cubical_layer.py
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -72,5 +72,5 @@ class CubicalLayer(tf.keras.layers.Layer):
if min_pers >= 0:
finite_dgm = self.dgms[idx_dim]
persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers)
- self.dgms[idx_dim] = tf.gather(finite_dgm, indices=persistent_indices)
+ self.dgms[idx_dim] = tf.reshape(tf.gather(finite_dgm, indices=persistent_indices), [-1,2])
return self.dgms