summaryrefslogtreecommitdiff
path: root/src/python/gudhi
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2021-11-12 12:34:43 +0100
committerMathieuCarriere <mathieu.carriere3@gmail.com>2021-11-12 12:34:43 +0100
commit1fd37bf29d665330f1eb242139bc0faf10a542c1 (patch)
tree021dccf630c74ffb1e68de9b075c6f1bfb53df7b /src/python/gudhi
parent74dfd101312a48272f2f91c3ddc401d1148deaec (diff)
avoid transpose
Diffstat (limited to 'src/python/gudhi')
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
index 70528f98..0971a446 100644
--- a/src/python/gudhi/tensorflow/cubical_layer.py
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -8,13 +8,14 @@ from ..cubical_complex import CubicalComplex
# The parameters of the model are the pixel values.
-def _Cubical(X, dimensions):
- # Parameters: X (image),
+def _Cubical(Xflat, Xdim, dimensions):
+ # Parameters: Xflat (flattened image),
+ # Xdim (shape of non-flattened image)
# dimensions (homology dimensions)
# Compute the persistence pairs with Gudhi
- Xs = X.shape
- cc = CubicalComplex(top_dimensional_cells=X)
+ # We reverse the dimensions because CubicalComplex uses Fortran ordering
+ cc = CubicalComplex(dimensions=Xdim[::-1], top_dimensional_cells=Xflat)
cc.compute_persistence()
# Retrieve and ouput image indices/pixels corresponding to positive and negative simplices
@@ -61,7 +62,9 @@ class CubicalLayer(tf.keras.layers.Layer):
"""
# Compute pixels associated to positive and negative simplices
# Don't compute gradient for this operation
- indices = _Cubical(X.numpy(), self.dimensions)
+ Xflat = tf.reshape(X, [-1])
+ Xdim = X.shape
+ indices = _Cubical(Xflat.numpy(), Xdim, self.dimensions)
# Get persistence diagram by simply picking the corresponding entries in the image
- self.dgms = [tf.reshape(tf.gather( tf.reshape(tf.transpose(X), [-1]), indice ), [-1,2]) for indice in indices]
+ self.dgms = [tf.reshape(tf.gather(Xflat, indice), [-1,2]) for indice in indices]
return self.dgms