summaryrefslogtreecommitdiff
path: root/src/python/gudhi/dtm_rips_complex.py
diff options
context:
space:
mode:
authoryuichi-ike <yuichi.ike.1990@gmail.com>2020-05-22 10:22:31 +0900
committeryuichi-ike <yuichi.ike.1990@gmail.com>2020-05-22 10:22:31 +0900
commit2ccc5ea97a5979f80fec93863da5549e4e6f2eea (patch)
tree187d8a3b84d144755f63d69bd9581f4d662eca35 /src/python/gudhi/dtm_rips_complex.py
parentc4e93ba5f1d003c442e3d56d6a0b3e80651dd6ec (diff)
class name changed, documents modified
Diffstat (limited to 'src/python/gudhi/dtm_rips_complex.py')
-rw-r--r--src/python/gudhi/dtm_rips_complex.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/src/python/gudhi/dtm_rips_complex.py b/src/python/gudhi/dtm_rips_complex.py
index 6d2f9f31..70c8e5dd 100644
--- a/src/python/gudhi/dtm_rips_complex.py
+++ b/src/python/gudhi/dtm_rips_complex.py
@@ -12,7 +12,7 @@ from gudhi.weighted_rips_complex import WeightedRipsComplex
from gudhi.point_cloud.dtm import DistanceToMeasure
from scipy.spatial.distance import cdist
-class DtmRipsComplex(WeightedRipsComplex):
+class DTMRipsComplex(WeightedRipsComplex):
"""
Class to generate a DTM Rips complex from a distance matrix or a point set,
in the way described in :cite:`dtmfiltrations`.
@@ -28,7 +28,7 @@ class DtmRipsComplex(WeightedRipsComplex):
"""
Args:
points (Sequence[Sequence[float]]): list of points.
- distance_matrix (ndarray): full distance matrix.
+ distance_matrix (numpy.ndarray): full distance matrix.
k (int): number of neighbors for the computation of DTM. Defaults to 1, which is equivalent to the usual Rips complex.
q (float): order used to compute the distance to measure. Defaults to 2.
max_filtration (float): specifies the maximal filtration value to be considered.
@@ -39,8 +39,12 @@ class DtmRipsComplex(WeightedRipsComplex):
points=[]
distance_matrix = cdist(points,points)
self.distance_matrix = distance_matrix
- dtm = DistanceToMeasure(k, q=q, metric="precomputed")
+
# TODO: address the error when k is too large
- self.weights = dtm.fit_transform(distance_matrix)
+ if k <= 1:
+ self.weights = [0] * len(distance_matrix)
+ else:
+ dtm = DistanceToMeasure(k, q=q, metric="precomputed")
+ self.weights = dtm.fit_transform(distance_matrix)
self.max_filtration = max_filtration