summaryrefslogtreecommitdiff
path: root/src/python/gudhi/dtm_rips_complex.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/dtm_rips_complex.py')
-rw-r--r--src/python/gudhi/dtm_rips_complex.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/src/python/gudhi/dtm_rips_complex.py b/src/python/gudhi/dtm_rips_complex.py
index 6d2f9f31..70c8e5dd 100644
--- a/src/python/gudhi/dtm_rips_complex.py
+++ b/src/python/gudhi/dtm_rips_complex.py
@@ -12,7 +12,7 @@ from gudhi.weighted_rips_complex import WeightedRipsComplex
from gudhi.point_cloud.dtm import DistanceToMeasure
from scipy.spatial.distance import cdist
-class DtmRipsComplex(WeightedRipsComplex):
+class DTMRipsComplex(WeightedRipsComplex):
"""
Class to generate a DTM Rips complex from a distance matrix or a point set,
in the way described in :cite:`dtmfiltrations`.
@@ -28,7 +28,7 @@ class DtmRipsComplex(WeightedRipsComplex):
"""
Args:
points (Sequence[Sequence[float]]): list of points.
- distance_matrix (ndarray): full distance matrix.
+ distance_matrix (numpy.ndarray): full distance matrix.
k (int): number of neighbors for the computation of DTM. Defaults to 1, which is equivalent to the usual Rips complex.
q (float): order used to compute the distance to measure. Defaults to 2.
max_filtration (float): specifies the maximal filtration value to be considered.
@@ -39,8 +39,12 @@ class DtmRipsComplex(WeightedRipsComplex):
points=[]
distance_matrix = cdist(points,points)
self.distance_matrix = distance_matrix
- dtm = DistanceToMeasure(k, q=q, metric="precomputed")
+
# TODO: address the error when k is too large
- self.weights = dtm.fit_transform(distance_matrix)
+ if k <= 1:
+ self.weights = [0] * len(distance_matrix)
+ else:
+ dtm = DistanceToMeasure(k, q=q, metric="precomputed")
+ self.weights = dtm.fit_transform(distance_matrix)
self.max_filtration = max_filtration