summaryrefslogtreecommitdiff
path: root/src/python/gudhi/dtm_rips_complex.py
blob: 63c9b1387e24d68d1e874db0599d807f156a65ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s):       Yuichi Ike, Raphaël Tinarrage
#
# Copyright (C) 2020 Inria, Copyright (C) 2020 Fujitsu Laboratories Ltd.
#
# Modification(s):
#   - YYYY/MM Author: Description of the modification


from gudhi.weighted_rips_complex import WeightedRipsComplex
from gudhi.point_cloud.dtm import DistanceToMeasure
from scipy.spatial.distance import cdist

class DTMRipsComplex(WeightedRipsComplex):
    """
    Class to generate a DTM Rips complex from a distance matrix or a point set,
    in the way described in :cite:`dtmfiltrations`.
    Note that all the filtration values are doubled compared to the definition in the paper,
    for consistency with RipsComplex.

    :Requires: `SciPy <installation.html#scipy>`_
    """
    def __init__(self,
                 points=None,
                 distance_matrix=None,
                 k=1,
                 q=2,
                 max_filtration=float('inf')):
        """
        Args:
            points (numpy.ndarray): array of points. Ignored when `distance_matrix` is given.
            distance_matrix (numpy.ndarray): full distance matrix.
            k (int): number of neighbors for the computation of DTM. Defaults to 1, which is equivalent to the usual Rips complex.
            q (float): order used to compute the distance to measure. Defaults to 2.
            max_filtration (float): specifies the maximal filtration value to be considered.

        Raises:
            ValueError: if `k` is larger than the number of points of a non-empty input.
        """
        if distance_matrix is None:
            if points is None or len(points) == 0:
                # Empty Rips construction. cdist only accepts 2-D inputs and rejects
                # a bare empty list, so build the (empty) distance matrix directly.
                distance_matrix = []
            else:
                distance_matrix = cdist(points, points)
        self.distance_matrix = distance_matrix

        num_points = len(distance_matrix)
        if k <= 1 or num_points == 0:
            # k <= 1 gives all-zero weights (the usual Rips complex, as documented
            # for `k`); an empty point set simply has no weights.
            self.weights = [0] * num_points
        elif k > num_points:
            # The DTM needs k neighbors per point; fail early with a clear message
            # instead of an opaque error from the nearest-neighbor search.
            raise ValueError(
                "k (= {}) must not exceed the number of points (= {})".format(k, num_points))
        else:
            dtm = DistanceToMeasure(k, q=q, metric="precomputed")
            self.weights = dtm.fit_transform(distance_matrix)
        self.max_filtration = max_filtration