1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
|
import numpy as np
import tensorflow as tf
from ..rips_complex import RipsComplex
############################
# Vietoris-Rips filtration #
############################
# The parameters of the model are the point coordinates.
def _Rips(DX, max_edge, dimension):
    """
    Compute, with Gudhi, the vertex indices that generate the Rips persistence pairs.

    Parameters:
        DX (numpy array): square matrix of pairwise distances
        max_edge (float): maximum edge length for the Rips filtration
        dimension (int): homology dimension

    Returns:
        Flat int32 numpy array holding four vertex indices per persistence pair
        (two for the birth simplex followed by two for the death simplex), with
        the pairs ordered by decreasing persistence.
    """
    def _farthest_pair(simplex):
        # The two vertices of the simplex achieving the maximal distance among
        # all pairwise distances between the simplex vertices
        verts = np.array(simplex)
        block = DX[verts, :][:, verts]
        size = len(verts)
        return [verts[k] for k in np.unravel_index(np.argmax(block), [size, size])]

    complex_ = RipsComplex(distance_matrix=DX, max_edge_length=max_edge)
    tree = complex_.create_simplex_tree(max_dimension=dimension + 1)
    # Persistence is computed before the pairs are requested; only the pairs
    # are used below, so the returned diagram is discarded
    tree.persistence()

    rows, lifetimes = [], []
    for birth, death in tree.persistence_pairs():
        # Keep only pairs born in the requested dimension that have a
        # non-empty death simplex
        if len(birth) != dimension + 1 or len(death) == 0:
            continue
        rows.append(_farthest_pair(birth) + _farthest_pair(death))
        lifetimes.append(tree.filtration(death) - tree.filtration(birth))

    # Sort points with distance-to-diagonal, largest first
    order = np.argsort(lifetimes)[::-1]
    ordered = np.reshape(rows, [-1, 4])[order].flatten()
    return np.array(ordered, dtype=np.int32)
class RipsLayer(tf.keras.layers.Layer):
    """
    TensorFlow layer for computing Rips persistence out of a point cloud

    Attributes:
        max_edge (float): maximum edge length for the Rips complex
        dimension (int): homology dimension
    """
    def __init__(self, maximum_edge_length=12, dimension=1, **kwargs):
        """
        Constructor for the RipsLayer class

        Parameters:
            maximum_edge_length (float): maximum edge length for the Rips complex
            dimension (int): homology dimension
            **kwargs: extra keyword arguments forwarded to tf.keras.layers.Layer
        """
        # dynamic=True because call() reads tensor values through .numpy()
        super().__init__(dynamic=True, **kwargs)
        self.max_edge = maximum_edge_length
        self.dimension = dimension

    def build(self, input_shape=None):
        """
        Build the layer; there are no weights to create.

        Parameters:
            input_shape: shape handed over by Keras; optional so that a bare
                build() call keeps working
        """
        # super must be called to obtain the parent proxy (a bare `super.build()`
        # raises AttributeError)
        super().build(input_shape)

    def call(self, X):
        """
        Compute Rips persistence diagram associated to a point cloud

        Parameters:
            X (TensorFlow variable): point cloud of shape [number of points, number of dimensions]

        Returns:
            dgm (TensorFlow tensor): persistence diagram with one (birth, death) row per
                persistence pair; births are all zero when the homology dimension is 0
        """
        # Compute distance matrix
        DX = tf.math.sqrt(tf.reduce_sum((tf.expand_dims(X, 1)-tf.expand_dims(X, 0))**2, 2))
        # Compute vertices associated to positive and negative simplices
        # Don't compute gradient for this operation
        indices = tf.stop_gradient(_Rips(DX.numpy(), self.max_edge, self.dimension))
        # Get persistence diagram by simply picking the corresponding entries in the distance matrix
        if self.dimension > 0:
            dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(indices, [-1,2])), [-1,2])
        else:
            # In dimension 0 only every second index pair (the death edge) is looked up;
            # the birth coordinate is set to zero
            indices = tf.reshape(indices, [-1,2])[1::2,:]
            deaths = tf.reshape(tf.gather_nd(DX, indices), [-1,1])
            # Match the dtype of DX so that float64 point clouds can be concatenated
            births = tf.zeros([indices.shape[0],1], dtype=DX.dtype)
            dgm = tf.concat([births, deaths], axis=1)
        return dgm
|