diff options
Diffstat (limited to 'src/python/gudhi/weighted_rips_complex.py')
-rw-r--r-- | src/python/gudhi/weighted_rips_complex.py | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/src/python/gudhi/weighted_rips_complex.py b/src/python/gudhi/weighted_rips_complex.py new file mode 100644 index 00000000..83fa82c5 --- /dev/null +++ b/src/python/gudhi/weighted_rips_complex.py @@ -0,0 +1,55 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Raphaƫl Tinarrage, Yuichi Ike, Masatoshi Takenouchi +# +# Copyright (C) 2020 Inria, Copyright (C) 2020 FUjitsu Laboratories Ltd. +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + +from gudhi import SimplexTree + +class WeightedRipsComplex: + """ + Class to generate a weighted Rips complex from a distance matrix and weights on vertices. + """ + def __init__(self, + distance_matrix, + weights="diagonal", + max_filtration=float('inf')): + """ + Args: + distance_matrix (list of list of float): distance matrix (full square or lower triangular). + weights (list of float): (one half of) weight for each vertex. + max_filtration (float): specifies the maximal filtration value to be considered. + """ + self.distance_matrix = distance_matrix + if weights == "diagonal": + self.weights = [distance_matrix[i][i] for i in range(len(distance_matrix))] + else: + self.weights = weights + self.max_filtration = max_filtration + + def create_simplex_tree(self, max_dimension): + """ + Args: + max_dimension (int): graph expansion until this given dimension. + """ + dist = self.distance_matrix + F = self.weights + num_pts = len(dist) + + st = SimplexTree() + + for i in range(num_pts): + if 2*F[i] <= self.max_filtration: + st.insert([i], 2*F[i]) + for i in range(num_pts): + for j in range(i): + value = max(2*F[i], 2*F[j], dist[i][j] + F[i] + F[j]) + if value <= self.max_filtration: + st.insert([i,j], filtration=value) + + st.expansion(max_dimension) + return st + |