diff options
Diffstat (limited to 'src/python/gudhi/weighted_rips_complex.py')
-rw-r--r-- | src/python/gudhi/weighted_rips_complex.py | 35 |
1 files changed, 15 insertions, 20 deletions
diff --git a/src/python/gudhi/weighted_rips_complex.py b/src/python/gudhi/weighted_rips_complex.py index 7e504b2c..83fa82c5 100644 --- a/src/python/gudhi/weighted_rips_complex.py +++ b/src/python/gudhi/weighted_rips_complex.py @@ -11,34 +11,29 @@ from gudhi import SimplexTree class WeightedRipsComplex: """ - class to generate a weighted Rips complex - from a distance matrix and weights on vertices + Class to generate a weighted Rips complex from a distance matrix and weights on vertices. """ def __init__(self, distance_matrix, - weights=None, + weights="diagonal", max_filtration=float('inf')): """ - Parameters: - distance_matrix: list of list of float, - distance matrix (full square or lower triangular) - filtration_values: list of float, - weight for each vertex - max_filtration: float, - specifies the maximal filtration value to be considered + Args: + distance_matrix (list of list of float): distance matrix (full square or lower triangular). + weights (list of float): (one half of) weight for each vertex. + max_filtration (float): specifies the maximal filtration value to be considered. """ self.distance_matrix = distance_matrix - if weights is not None: - self.weights = weights + if weights == "diagonal": + self.weights = [distance_matrix[i][i] for i in range(len(distance_matrix))] else: - self.weights = [0] * len(distance_matrix) + self.weights = weights self.max_filtration = max_filtration def create_simplex_tree(self, max_dimension): """ - Parameter: - max_dimension: int - graph expansion until this given dimension + Args: + max_dimension (int): graph expansion until this given dimension. """ dist = self.distance_matrix F = self.weights @@ -47,12 +42,12 @@ class WeightedRipsComplex: st = SimplexTree() for i in range(num_pts): - if F[i] < self.max_filtration: - st.insert([i], F[i]) + if 2*F[i] <= self.max_filtration: + st.insert([i], 2*F[i]) for i in range(num_pts): for j in range(i): - value = max(F[i], F[j], (dist[i][j] + F[i] + F[j]) / 2) - if value < self.max_filtration: + value = max(2*F[i], 2*F[j], dist[i][j] + F[i] + F[j]) + if value <= self.max_filtration: st.insert([i,j], filtration=value) st.expansion(max_dimension) |