summaryrefslogtreecommitdiff
path: root/src/python/test/test_dtm.py
blob: 57fdd131bdb60585614430adf88dd402afb73438 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
    See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
    Author(s):       Marc Glisse

    Copyright (C) 2020 Inria

    Modification(s):
      - YYYY/MM Author: Description of the modification
"""

from gudhi.point_cloud.dtm import DTM
import numpy


def test_dtm_euclidean():
    """Check that the DTM variants agree with each other on one point cloud.

    A 1000 x 4 array of uniform random points is passed to ``fit_transform``
    of ``DTM(k=3)`` built with each of the "ckdtree", "sklearn" (default and
    ``algorithm="brute"``), "hnsw" and "keops" implementations, and the
    pairwise distance matrix of the same points is passed to
    ``DTM(k, metric="precomputed")``. The "ckdtree" output is taken as the
    reference and every other output is compared against it.

    Raises:
        AssertionError: if any variant differs from the reference by more
            than the tolerance used for it.
    """
    # Fixed seed so that a failing run can be reproduced exactly.
    pts = numpy.random.RandomState(0).rand(1000, 4)
    k = 3

    reference = DTM(k, implementation="ckdtree").fit_transform(pts)

    result = DTM(k, implementation="sklearn").fit_transform(pts)
    numpy.testing.assert_allclose(result, reference, rtol=1e-6)

    result = DTM(k, implementation="sklearn", algorithm="brute").fit_transform(pts)
    numpy.testing.assert_allclose(result, reference, rtol=1e-6)

    # NOTE(review): "hnsw" is presumably an approximate nearest-neighbour
    # backend, hence the much looser tolerance -- confirm against the DTM docs.
    result = DTM(k, implementation="hnsw").fit_transform(pts)
    numpy.testing.assert_allclose(result, reference, rtol=0.1)

    # Same points, but supplied as an explicit pairwise distance matrix.
    from scipy.spatial.distance import cdist

    pairwise = cdist(pts, pts)
    result = DTM(k, metric="precomputed").fit_transform(pairwise)
    numpy.testing.assert_allclose(result, reference, rtol=1e-6)

    result = DTM(k, implementation="keops").fit_transform(pts)
    numpy.testing.assert_allclose(result, reference, rtol=1e-6)