From 8c9a1c674dcacc8b66e88897b6116561bb811ffa Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 11 May 2020 21:55:21 +0200 Subject: Handle k=1 in KNearestNeighbors with SciPy --- src/python/gudhi/point_cloud/knn.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'src/python/gudhi/point_cloud') diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py index 34e80b5d..65896847 100644 --- a/src/python/gudhi/point_cloud/knn.py +++ b/src/python/gudhi/point_cloud/knn.py @@ -302,6 +302,10 @@ class KNearestNeighbors: if self.params["implementation"] == "ckdtree": qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}} distances, neighbors = self.kdtree.query(X, k=self.k, **qargs) + if k == 1: + # SciPy decided to squeeze the last dimension for k=1 + distances = distances[:, None] + neighbors = neighbors[:, None] if self.return_index: if self.return_distance: return neighbors, distances -- cgit v1.2.3