summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/cubical_layer.py
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2021-12-04 12:41:59 +0100
committerMathieuCarriere <mathieu.carriere3@gmail.com>2021-12-04 12:41:59 +0100
commit979d12e00b4ea71391d132589ee3304e378459b9 (patch)
treed2a6c8e3c8d344ee5d860cfad775fbef5d10b43d /src/python/gudhi/tensorflow/cubical_layer.py
parente9b297ec86d79e2b5b2fd4ce63033f8697f053da (diff)
added min persistence
Diffstat (limited to 'src/python/gudhi/tensorflow/cubical_layer.py')
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py12
1 files changed, 5 insertions, 7 deletions
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
index d07a4cd8..8fe9cff0 100644
--- a/src/python/gudhi/tensorflow/cubical_layer.py
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -8,7 +8,7 @@ from ..cubical_complex import CubicalComplex
# The parameters of the model are the pixel values.
-def _Cubical(Xflat, Xdim, dimensions):
+def _Cubical(Xflat, Xdim, dimensions, min_persistence):
# Parameters: Xflat (flattened image),
# Xdim (shape of non-flattened image)
# dimensions (homology dimensions)
@@ -16,7 +16,7 @@ def _Cubical(Xflat, Xdim, dimensions):
# Compute the persistence pairs with Gudhi
# We reverse the dimensions because CubicalComplex uses Fortran ordering
cc = CubicalComplex(dimensions=Xdim[::-1], top_dimensional_cells=Xflat)
- cc.compute_persistence()
+ cc.compute_persistence(min_persistence=min_persistence)
# Retrieve and ouput image indices/pixels corresponding to positive and negative simplices
cof_pp = cc.cofaces_of_persistence_pairs()
@@ -37,7 +37,7 @@ class CubicalLayer(tf.keras.layers.Layer):
"""
TensorFlow layer for computing cubical persistence out of a cubical complex
"""
- def __init__(self, dimensions, **kwargs):
+ def __init__(self, dimensions, min_persistence=0., **kwargs):
"""
Constructor for the CubicalLayer class
@@ -46,9 +46,7 @@ class CubicalLayer(tf.keras.layers.Layer):
"""
super().__init__(dynamic=True, **kwargs)
self.dimensions = dimensions
-
- def build(self):
- super.build()
+ self.min_persistence = min_persistence
def call(self, X):
"""
@@ -64,7 +62,7 @@ class CubicalLayer(tf.keras.layers.Layer):
# Don't compute gradient for this operation
Xflat = tf.reshape(X, [-1])
Xdim = X.shape
- indices = _Cubical(Xflat.numpy(), Xdim, self.dimensions)
+ indices = _Cubical(Xflat.numpy(), Xdim, self.dimensions, self.min_persistence)
# Get persistence diagram by simply picking the corresponding entries in the image
self.dgms = [tf.reshape(tf.gather(Xflat, indice), [-1,2]) for indice in indices]
return self.dgms