From 27f8df308e3ed935e4ef9f62d23717efebdf36ae Mon Sep 17 00:00:00 2001 From: MathieuCarriere Date: Tue, 12 Apr 2022 15:21:02 +0200 Subject: fix doc + reshape in cubical --- src/python/doc/cubical_complex_tflow_itf_ref.rst | 2 +- src/python/gudhi/tensorflow/cubical_layer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'src/python') diff --git a/src/python/doc/cubical_complex_tflow_itf_ref.rst b/src/python/doc/cubical_complex_tflow_itf_ref.rst index 18b97adf..881a2950 100644 --- a/src/python/doc/cubical_complex_tflow_itf_ref.rst +++ b/src/python/doc/cubical_complex_tflow_itf_ref.rst @@ -19,7 +19,7 @@ Example of gradient computed from cubical persistence cl = CubicalLayer(dimensions=[0]) with tf.GradientTape() as tape: - dgm = cl.call(X)[0][0] + dgm = cl.call(X)[0] loss = tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0]))) grads = tape.gradient(loss, [X]) diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py index 369b0e54..31c44205 100644 --- a/src/python/gudhi/tensorflow/cubical_layer.py +++ b/src/python/gudhi/tensorflow/cubical_layer.py @@ -72,5 +72,5 @@ class CubicalLayer(tf.keras.layers.Layer): if min_pers >= 0: finite_dgm = self.dgms[idx_dim] persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers) - self.dgms[idx_dim] = tf.gather(finite_dgm, indices=persistent_indices) + self.dgms[idx_dim] = tf.reshape(tf.gather(finite_dgm, indices=persistent_indices), [-1,2]) return self.dgms -- cgit v1.2.3