From 280eb9d2323837619db1ae013b929adb9b45013b Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 13 Apr 2020 01:09:45 +0200 Subject: enable_autodiff with keops This doesn't seem like the best way to handle it, we may want to handle it like a wrapper that gets the indices from knn (whatever backend) and then computes the distances. --- src/python/test/test_knn.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'src/python/test/test_knn.py') diff --git a/src/python/test/test_knn.py b/src/python/test/test_knn.py index 6269df54..415c9d48 100755 --- a/src/python/test/test_knn.py +++ b/src/python/test/test_knn.py @@ -32,6 +32,12 @@ def test_knn_explicit(): .transform(query) ) assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]])) + r = ( + KNearestNeighbors(2, metric="chebyshev", return_distance=True, return_index=False, implementation="keops", enable_autodiff=True) + .fit(base) + .transform(query) + ) + assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]])) knn = KNearestNeighbors(2, metric="minkowski", p=3, return_distance=False, return_index=True) knn.fit(base) -- cgit v1.2.3