diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-13 01:09:45 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-13 01:10:05 +0200 |
commit | 280eb9d2323837619db1ae013b929adb9b45013b (patch) | |
tree | 19913f2113c38e9c7783cf5b186dd7a34dd921b2 /src/python/test/test_knn.py | |
parent | 83a1bc1fb6124a35d515f4836d2e830f3dbdf0e7 (diff) |
enable_autodiff with keops
This doesn't seem like the best way to handle it, we may want to handle
it like a wrapper that gets the indices from knn (whatever backend) and
then computes the distances.
Diffstat (limited to 'src/python/test/test_knn.py')
-rwxr-xr-x | src/python/test/test_knn.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/src/python/test/test_knn.py b/src/python/test/test_knn.py index 6269df54..415c9d48 100755 --- a/src/python/test/test_knn.py +++ b/src/python/test/test_knn.py @@ -32,6 +32,12 @@ def test_knn_explicit(): .transform(query) ) assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]])) + r = ( + KNearestNeighbors(2, metric="chebyshev", return_distance=True, return_index=False, implementation="keops", enable_autodiff=True) + .fit(base) + .transform(query) + ) + assert r == pytest.approx(np.array([[0.0, 1], [1, 1], [1, 2]])) knn = KNearestNeighbors(2, metric="minkowski", p=3, return_distance=False, return_index=True) knn.fit(base) |