diff options
Diffstat (limited to 'src/python/gudhi/tensorflow/rips_layer.py')
-rw-r--r-- | src/python/gudhi/tensorflow/rips_layer.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/src/python/gudhi/tensorflow/rips_layer.py b/src/python/gudhi/tensorflow/rips_layer.py index 88d501c1..7b5edfa3 100644 --- a/src/python/gudhi/tensorflow/rips_layer.py +++ b/src/python/gudhi/tensorflow/rips_layer.py @@ -83,8 +83,8 @@ class RipsLayer(tf.keras.layers.Layer): essential_dgm = tf.zeros([cur_idx[1].shape[0],1]) min_pers = self.min_persistence[idx_dim] if min_pers >= 0: - persistent_indices = np.argwhere(np.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers).ravel() - self.dgms.append((tf.gather(finite_dgm, indices=persistent_indices), essential_dgm)) + persistent_indices = tf.where(tf.math.abs(finite_dgm[:,1]-finite_dgm[:,0]) > min_pers) + self.dgms.append((tf.reshape(tf.gather(finite_dgm, indices=persistent_indices),[-1,2]), essential_dgm)) else: self.dgms.append((finite_dgm, essential_dgm)) return self.dgms |