blob: 83fa82c54afd66fa32dde50c69968e14135549fb (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
|
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s): Raphaël Tinarrage, Yuichi Ike, Masatoshi Takenouchi
#
# Copyright (C) 2020 Inria, Copyright (C) 2020 Fujitsu Laboratories Ltd.
#
# Modification(s):
# - YYYY/MM Author: Description of the modification
from gudhi import SimplexTree
class WeightedRipsComplex:
    """
    Class to generate a weighted Rips complex from a distance matrix and weights on vertices.

    The filtration value of vertex ``i`` is ``2*weights[i]``. The filtration value of edge
    ``ij`` is the largest of ``2*weights[i]``, ``2*weights[j]`` and
    ``distance_matrix[i][j] + weights[i] + weights[j]``, so an edge never appears before its
    endpoints. Vertices and edges whose filtration value exceeds ``max_filtration`` are not
    inserted. Higher-dimensional simplices come from the flag expansion of the resulting graph.
    """

    def __init__(self,
                 distance_matrix,
                 weights="diagonal",
                 max_filtration=float('inf')):
        """
        Args:
            distance_matrix (Sequence[Sequence[float]]): distance matrix (full square or lower
                triangular). Edges only read entries ``distance_matrix[i][j]`` with ``j < i``.
                When the default weights are used, the diagonal entries
                ``distance_matrix[i][i]`` must also be present.
            weights (Sequence[float] or str or None): (one half of) the weight of each vertex.
                The default ``"diagonal"`` (or ``None``) uses ``distance_matrix[i][i]`` as the
                weight of vertex ``i``.
            max_filtration (float): specifies the maximal filtration value to be considered.

        Raises:
            ValueError: if ``weights`` is a string other than ``"diagonal"``, or if the number
                of weights differs from the number of rows of ``distance_matrix``.
        """
        self.distance_matrix = distance_matrix
        # Check the type before comparing to the string: on a NumPy array, ``== "diagonal"``
        # is an elementwise comparison and using its result in ``if`` raises ValueError.
        if weights is None or (isinstance(weights, str) and weights == "diagonal"):
            self.weights = [distance_matrix[i][i] for i in range(len(distance_matrix))]
        elif isinstance(weights, str):
            raise ValueError('weights must be a sequence of floats, None or "diagonal", got %r'
                             % (weights,))
        elif len(weights) != len(distance_matrix):
            raise ValueError("The length of the weights and the size of the distance matrix "
                             "should be the same")
        else:
            self.weights = weights
        self.max_filtration = max_filtration

    def create_simplex_tree(self, max_dimension):
        """
        Args:
            max_dimension (int): graph expansion until this given dimension.

        Returns:
            SimplexTree: a simplex tree holding every vertex and edge whose filtration value is
            at most ``max_filtration``, expanded up to ``max_dimension``.
        """
        dist = self.distance_matrix
        weights = self.weights
        num_pts = len(dist)
        # Vertex filtration values, computed once and reused for every incident edge.
        vertex_values = [2 * weights[i] for i in range(num_pts)]
        max_filtration = self.max_filtration

        st = SimplexTree()
        for i, vertex_value in enumerate(vertex_values):
            if vertex_value <= max_filtration:
                st.insert([i], filtration=vertex_value)
        for i in range(num_pts):
            # Only j < i is visited, so a lower triangular matrix is sufficient.
            for j in range(i):
                value = max(vertex_values[i], vertex_values[j],
                            dist[i][j] + weights[i] + weights[j])
                if value <= max_filtration:
                    st.insert([i, j], filtration=value)
        st.expansion(max_dimension)
        return st
|