diff options
author | MathieuCarriere <mathieu.carriere3@gmail.com> | 2021-11-12 12:34:43 +0100 |
---|---|---|
committer | MathieuCarriere <mathieu.carriere3@gmail.com> | 2021-11-12 12:34:43 +0100 |
commit | 1fd37bf29d665330f1eb242139bc0faf10a542c1 (patch) | |
tree | 021dccf630c74ffb1e68de9b075c6f1bfb53df7b /src/python | |
parent | 74dfd101312a48272f2f91c3ddc401d1148deaec (diff) |
avoid transpose
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/gudhi/tensorflow/cubical_layer.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py index 70528f98..0971a446 100644 --- a/src/python/gudhi/tensorflow/cubical_layer.py +++ b/src/python/gudhi/tensorflow/cubical_layer.py @@ -8,13 +8,14 @@ from ..cubical_complex import CubicalComplex # The parameters of the model are the pixel values. -def _Cubical(X, dimensions): - # Parameters: X (image), +def _Cubical(Xflat, Xdim, dimensions): + # Parameters: Xflat (flattened image), + # Xdim (shape of non-flattened image) # dimensions (homology dimensions) # Compute the persistence pairs with Gudhi - Xs = X.shape - cc = CubicalComplex(top_dimensional_cells=X) + # We reverse the dimensions because CubicalComplex uses Fortran ordering + cc = CubicalComplex(dimensions=Xdim[::-1], top_dimensional_cells=Xflat) cc.compute_persistence() # Retrieve and ouput image indices/pixels corresponding to positive and negative simplices @@ -61,7 +62,9 @@ class CubicalLayer(tf.keras.layers.Layer): """ # Compute pixels associated to positive and negative simplices # Don't compute gradient for this operation - indices = _Cubical(X.numpy(), self.dimensions) + Xflat = tf.reshape(X, [-1]) + Xdim = X.shape + indices = _Cubical(Xflat.numpy(), Xdim, self.dimensions) # Get persistence diagram by simply picking the corresponding entries in the image - self.dgms = [tf.reshape(tf.gather( tf.reshape(tf.transpose(X), [-1]), indice ), [-1,2]) for indice in indices] + self.dgms = [tf.reshape(tf.gather(Xflat, indice), [-1,2]) for indice in indices] return self.dgms |