summaryrefslogtreecommitdiff
path: root/src/python/test/test_dtm.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-27 13:43:58 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-27 13:43:58 +0100
commitf74c71ca8e474ff927cae029ea63329d30293582 (patch)
tree4696a751c330e2c5a4bc710c28ba72bcee0579f6 /src/python/test/test_dtm.py
parentaf35ea5b4ce631ae826f1db1940798f254aba658 (diff)
Improve coverage
Diffstat (limited to 'src/python/test/test_dtm.py')
-rwxr-xr-xsrc/python/test/test_dtm.py48
1 files changed, 33 insertions, 15 deletions
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 57fdd131..841f8c3c 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -10,23 +10,41 @@
from gudhi.point_cloud.dtm import DTM
import numpy
+import pytest
-def test_dtm_euclidean():
- pts = numpy.random.rand(1000,4)
+def test_dtm_compare_euclidean():
+ pts = numpy.random.rand(1000, 4)
k = 3
- dtm = DTM(k,implementation="ckdtree")
- print(dtm.fit_transform(pts))
- dtm = DTM(k,implementation="sklearn")
- print(dtm.fit_transform(pts))
- dtm = DTM(k,implementation="sklearn",algorithm="brute")
- print(dtm.fit_transform(pts))
- dtm = DTM(k,implementation="hnsw")
- print(dtm.fit_transform(pts))
+ dtm = DTM(k, implementation="ckdtree")
+ r0 = dtm.fit_transform(pts)
+ dtm = DTM(k, implementation="sklearn")
+ r1 = dtm.fit_transform(pts)
+ assert r1 == pytest.approx(r0)
+ dtm = DTM(k, implementation="sklearn", algorithm="brute")
+ r2 = dtm.fit_transform(pts)
+ assert r2 == pytest.approx(r0)
+ dtm = DTM(k, implementation="hnsw")
+ r3 = dtm.fit_transform(pts)
+ assert r3 == pytest.approx(r0)
from scipy.spatial.distance import cdist
- d = cdist(pts,pts)
- dtm = DTM(k,metric="precomputed")
- print(dtm.fit_transform(d))
- dtm = DTM(k,implementation="keops")
- print(dtm.fit_transform(pts))
+ d = cdist(pts, pts)
+ dtm = DTM(k, metric="precomputed")
+ r4 = dtm.fit_transform(d)
+ assert r4 == pytest.approx(r0)
+ dtm = DTM(k, implementation="keops")
+ r5 = dtm.fit_transform(pts)
+ assert r5 == pytest.approx(r0)
+
+
+def test_dtm_precomputed():
+ dist = numpy.array([[1.0, 3, 8], [1, 5, 5], [0, 2, 3]])
+ dtm = DTM(2, q=1, metric="neighbors")
+ r = dtm.fit_transform(dist)
+ assert r == pytest.approx([2.0, 3, 1])
+
+ dist = numpy.array([[2.0, 2], [0, 1], [3, 4]])
+ dtm = DTM(2, q=2, metric="neighbors")
+ r = dtm.fit_transform(dist)
+ assert r == pytest.approx([2.0, .707, 3.5355], rel=.01)