summaryrefslogtreecommitdiff
path: root/src/python/gudhi/tensorflow/cubical_layer.py
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2021-11-12 09:46:22 +0100
committerMathieuCarriere <mathieu.carriere3@gmail.com>2021-11-12 09:46:22 +0100
commit6ae793a8cad4503d1795e227d40d85d43954d1dd (patch)
treee46c9cc11628456739008f281a860a4df1d20775 /src/python/gudhi/tensorflow/cubical_layer.py
parent3f1a6e659611dce2913fddc93b01480f05fb7983 (diff)
removed unraveling in cubical
Diffstat (limited to 'src/python/gudhi/tensorflow/cubical_layer.py')
-rw-r--r--src/python/gudhi/tensorflow/cubical_layer.py15
1 files changed, 4 insertions, 11 deletions
diff --git a/src/python/gudhi/tensorflow/cubical_layer.py b/src/python/gudhi/tensorflow/cubical_layer.py
index 55bd2685..70528f98 100644
--- a/src/python/gudhi/tensorflow/cubical_layer.py
+++ b/src/python/gudhi/tensorflow/cubical_layer.py
@@ -17,6 +17,7 @@ def _Cubical(X, dimensions):
cc = CubicalComplex(top_dimensional_cells=X)
cc.compute_persistence()
+ # Retrieve and ouput image indices/pixels corresponding to positive and negative simplices
cof_pp = cc.cofaces_of_persistence_pairs()
L_cofs = []
@@ -27,15 +28,7 @@ def _Cubical(X, dimensions):
except IndexError:
cof = np.array([])
- # Retrieve and ouput image indices/pixels corresponding to positive and negative simplices
- D = len(Xs) if len(cof) > 0 else 1
- ocof = np.zeros(D*2*cof.shape[0])
- count = 0
- for idx in range(0,2*cof.shape[0],2):
- ocof[D*idx:D*(idx+1)] = np.unravel_index(cof[count,0], Xs, order='F')
- ocof[D*(idx+1):D*(idx+2)] = np.unravel_index(cof[count,1], Xs, order='F')
- count += 1
- L_cofs.append(np.array(ocof, dtype=np.int32))
+ L_cofs.append(np.array(cof, dtype=np.int32))
return L_cofs
@@ -43,7 +36,7 @@ class CubicalLayer(tf.keras.layers.Layer):
"""
TensorFlow layer for computing cubical persistence out of a cubical complex
"""
- def __init__(self, dimensions=[0], **kwargs):
+ def __init__(self, dimensions, **kwargs):
"""
Constructor for the CubicalLayer class
@@ -70,5 +63,5 @@ class CubicalLayer(tf.keras.layers.Layer):
# Don't compute gradient for this operation
indices = _Cubical(X.numpy(), self.dimensions)
# Get persistence diagram by simply picking the corresponding entries in the image
- self.dgms = [tf.reshape(tf.gather_nd(X, tf.reshape(indice, [-1,len(X.shape)])), [-1,2]) for indice in indices]
+ self.dgms = [tf.reshape(tf.gather( tf.reshape(tf.transpose(X), [-1]), indice ), [-1,2]) for indice in indices]
return self.dgms