summaryrefslogtreecommitdiff
path: root/src/python/gudhi/weighted_rips_complex.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/weighted_rips_complex.py')
-rw-r--r--src/python/gudhi/weighted_rips_complex.py35
1 files changed, 15 insertions, 20 deletions
diff --git a/src/python/gudhi/weighted_rips_complex.py b/src/python/gudhi/weighted_rips_complex.py
index 7e504b2c..83fa82c5 100644
--- a/src/python/gudhi/weighted_rips_complex.py
+++ b/src/python/gudhi/weighted_rips_complex.py
@@ -11,34 +11,29 @@ from gudhi import SimplexTree
class WeightedRipsComplex:
"""
- class to generate a weighted Rips complex
- from a distance matrix and weights on vertices
+ Class to generate a weighted Rips complex from a distance matrix and weights on vertices.
"""
def __init__(self,
distance_matrix,
- weights=None,
+ weights="diagonal",
max_filtration=float('inf')):
"""
- Parameters:
- distance_matrix: list of list of float,
- distance matrix (full square or lower triangular)
- filtration_values: list of float,
- weight for each vertex
- max_filtration: float,
- specifies the maximal filtration value to be considered
+ Args:
+ distance_matrix (list of list of float): distance matrix (full square or lower triangular).
+ weights (list of float): (one half of) weight for each vertex.
+ max_filtration (float): specifies the maximal filtration value to be considered.
"""
self.distance_matrix = distance_matrix
- if weights is not None:
- self.weights = weights
+ if weights == "diagonal":
+ self.weights = [distance_matrix[i][i] for i in range(len(distance_matrix))]
else:
- self.weights = [0] * len(distance_matrix)
+ self.weights = weights
self.max_filtration = max_filtration
def create_simplex_tree(self, max_dimension):
"""
- Parameter:
- max_dimension: int
- graph expansion until this given dimension
+ Args:
+ max_dimension (int): graph expansion until this given dimension.
"""
dist = self.distance_matrix
F = self.weights
@@ -47,12 +42,12 @@ class WeightedRipsComplex:
st = SimplexTree()
for i in range(num_pts):
- if F[i] < self.max_filtration:
- st.insert([i], F[i])
+ if 2*F[i] <= self.max_filtration:
+ st.insert([i], 2*F[i])
for i in range(num_pts):
for j in range(i):
- value = max(F[i], F[j], (dist[i][j] + F[i] + F[j]) / 2)
- if value < self.max_filtration:
+ value = max(2*F[i], 2*F[j], dist[i][j] + F[i] + F[j])
+ if value <= self.max_filtration:
st.insert([i,j], filtration=value)
st.expansion(max_dimension)