summaryrefslogtreecommitdiff
path: root/src/python/test/test_knn.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-13 01:09:45 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-13 01:10:05 +0200
commit280eb9d2323837619db1ae013b929adb9b45013b (patch)
tree19913f2113c38e9c7783cf5b186dd7a34dd921b2 /src/python/test/test_knn.py
parent83a1bc1fb6124a35d515f4836d2e830f3dbdf0e7 (diff)
enable_autodiff with keops
This doesn't seem like the best way to handle it, we may want to handle it like a wrapper that gets the indices from knn (whatever backend) and then computes the distances.
Diffstat (limited to 'src/python/test/test_knn.py')
-rwxr-xr-xsrc/python/test/test_knn.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/src/python/test/test_knn.py b/src/python/test/test_knn.py
index 6269df54..415c9d48 100755
--- a/src/python/test/test_knn.py
+++ b/src/python/test/test_knn.py
@@ -32,6 +32,12 @@ def test_knn_explicit():
.transform(query)
)
assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
+ r = (
+ KNearestNeighbors(2, metric="chebyshev", return_distance=True, return_index=False, implementation="keops", enable_autodiff=True)
+ .fit(base)
+ .transform(query)
+ )
+ assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]]))
knn = KNearestNeighbors(2, metric="minkowski", p=3, return_distance=False, return_index=True)
knn.fit(base)