diff options
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r-- | src/python/gudhi/point_cloud/knn.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py index bb7757f2..8369f1f8 100644 --- a/src/python/gudhi/point_cloud/knn.py +++ b/src/python/gudhi/point_cloud/knn.py @@ -33,6 +33,9 @@ class KNN: p (float): norm L^p on input points (including numpy.inf) if metric is "minkowski". Defaults to 2. n_jobs (int): number of jobs to schedule for parallel processing of nearest neighbors on the CPU. If -1 is given all processors are used. Default: 1. + sort_results (bool): if True, then distances and indices of each point are + sorted on return, so that the first column contains the closest points. + Otherwise, neighbors are returned in an arbitrary order. Defaults to True. kwargs: additional parameters are forwarded to the backends. """ self.k = k @@ -115,18 +118,22 @@ class KNN: X = numpy.array(X) if self.return_index: neighbors = numpy.argpartition(X, k - 1)[:, 0:k] - distances = numpy.take_along_axis(X, neighbors, axis=-1) - ngb_order = numpy.argsort(distances, axis=-1) - neighbors = numpy.take_along_axis(neighbors, ngb_order, axis=-1) + if self.params.get("sort_results", True): + X = numpy.take_along_axis(X, neighbors, axis=-1) + ngb_order = numpy.argsort(X, axis=-1) + neighbors = numpy.take_along_axis(neighbors, ngb_order, axis=-1) + else: + ngb_order = neighbors if self.return_distance: - distances = numpy.take_along_axis(distances, ngb_order, axis=-1) + distances = numpy.take_along_axis(X, ngb_order, axis=-1) return neighbors, distances else: return neighbors if self.return_distance: distances = numpy.partition(X, k - 1)[:, 0:k] - # partition is not guaranteed to sort the lower half, although it often does - distances.sort(axis=-1) + if self.params.get("sort_results"): + # partition is not guaranteed to sort the lower half, although it often does + distances.sort(axis=-1) return distances return None |