summaryrefslogtreecommitdiff
path: root/src/python/test/test_knn.py
blob: e455fb48f64f14bf845d9e5c9f2507bb9b9b83d8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
    See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
    Author(s):       Marc Glisse

    Copyright (C) 2020 Inria

    Modification(s):
      - YYYY/MM Author: Description of the modification
"""

from gudhi.point_cloud.knn import KNN
import numpy as np
import pytest


def test_knn_explicit():
    """Compare KNN output with hand-written expectations on a tiny point set.

    Covers the "manhattan", "chebyshev", "minkowski" (p=3) and "precomputed"
    metrics, with the different combinations of returned indices/distances.
    """
    points = np.array([[1.0, 1], [1, 2], [4, 2], [4, 3]])
    targets = np.array([[1.0, 1], [2, 2], [4, 4]])

    # Manhattan metric, asking for both outputs: indices come first, distances second.
    manhattan = KNN(2, metric="manhattan", return_distance=True, return_index=True)
    manhattan.fit(points)
    found = manhattan.transform(targets)
    assert found[0] == pytest.approx(np.array([[0, 1], [1, 0], [3, 2]]))
    assert found[1] == pytest.approx(np.array([[0.0, 1], [1, 2], [1, 2]]))

    # Chebyshev metric, distances only: default implementation first, then keops.
    chebyshev_distances = np.array([[0.0, 1], [1, 1], [1, 2]])
    for extra in ({}, {"implementation": "keops"}):
        found = (
            KNN(2, metric="chebyshev", return_distance=True, return_index=False, **extra)
            .fit(points)
            .transform(targets)
        )
        assert found == pytest.approx(chebyshev_distances)

    # Minkowski metric with p=3, indices only: default implementation first, then keops.
    minkowski_indices = [[0, 1], [1, 0], [3, 2]]
    for extra in ({}, {"implementation": "keops"}):
        found = (
            KNN(2, metric="minkowski", p=3, return_distance=False, return_index=True, **extra)
            .fit(points)
            .transform(targets)
        )
        assert np.array_equal(found, minkowski_indices)

    # Precomputed metric: the input is the distance matrix itself.
    matrix = np.array([[0.0, 3, 8], [1, 0, 5], [1, 2, 0]])
    precomputed_indices = [[0, 1], [1, 0], [2, 0]]

    found = KNN(2, metric="precomputed", return_index=True, return_distance=False).fit_transform(matrix)
    assert np.array_equal(found, precomputed_indices)

    found = KNN(2, metric="precomputed", return_index=True, return_distance=True).fit_transform(matrix)
    assert np.array_equal(found[0], precomputed_indices)
    assert np.array_equal(found[1], [[0, 3], [0, 1], [0, 1]])


def test_knn_compare():
    """Check that every KNN implementation agrees with the "ckdtree" one.

    The first pass compares neighbor indices only; the second pass requests
    indices and distances, comparing indices exactly and distances approximately.
    """
    points = np.array([[1.0, 1], [1, 2], [4, 2], [4, 3]])
    targets = np.array([[1.0, 1], [2, 2], [4, 4]])
    # The first backend listed serves as the reference for the others.
    backends = ("ckdtree", "sklearn", "hnsw", "keops")

    index_results = [
        KNN(2, implementation=backend, return_index=True, return_distance=False).fit(points).transform(targets)
        for backend in backends
    ]
    reference, *others = index_results
    assert all(np.array_equal(reference, other) for other in others)

    full_results = [
        KNN(2, implementation=backend, return_index=True, return_distance=True).fit(points).transform(targets)
        for backend in backends
    ]
    reference, *others = full_results
    # Element 0 of each result holds the indices, element 1 the distances.
    assert all(np.array_equal(reference[0], other[0]) for other in others)
    expected_distances = pytest.approx(reference[1])
    assert all(other[1] == expected_distances for other in others)


def test_knn_nop():
    """Check that fit_transform returns None when no output is requested.

    Each variant disables both return_index and return_distance.
    """
    # This doesn't look super useful...
    single_point = np.array([[0.0]])
    variants = (
        {"implementation": "sklearn"},
        {"implementation": "ckdtree"},
        {"implementation": "hnsw", "ef": 5},
        {"implementation": "keops"},
        {"metric": "precomputed"},
    )
    for options in variants:
        knn = KNN(k=1, return_index=False, return_distance=False, **options)
        assert knn.fit_transform(single_point) is None