summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-28 12:45:00 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-28 12:45:00 +0100
commit7f323484acdeafca93efdd9bdd20ed428f8fb95b (patch)
treecc1b69d50936d6cb028feb150b47a507653412bd /src/python/gudhi/point_cloud/knn.py
parent40f4b6fb1fe20c3843b1fd80f99996e6d25c9426 (diff)
Optional sort_results
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index bb7757f2..8369f1f8 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -33,6 +33,9 @@ class KNN:
p (float): norm L^p on input points (including numpy.inf) if metric is "minkowski". Defaults to 2.
n_jobs (int): number of jobs to schedule for parallel processing of nearest neighbors on the CPU.
If -1 is given all processors are used. Default: 1.
+ sort_results (bool): if True, then distances and indices of each point are
+ sorted on return, so that the first column contains the closest points.
+ Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
kwargs: additional parameters are forwarded to the backends.
"""
self.k = k
@@ -115,18 +118,22 @@ class KNN:
X = numpy.array(X)
if self.return_index:
neighbors = numpy.argpartition(X, k - 1)[:, 0:k]
- distances = numpy.take_along_axis(X, neighbors, axis=-1)
- ngb_order = numpy.argsort(distances, axis=-1)
- neighbors = numpy.take_along_axis(neighbors, ngb_order, axis=-1)
+ if self.params.get("sort_results", True):
+ X = numpy.take_along_axis(X, neighbors, axis=-1)
+ ngb_order = numpy.argsort(X, axis=-1)
+ neighbors = numpy.take_along_axis(neighbors, ngb_order, axis=-1)
+ else:
+ ngb_order = neighbors
if self.return_distance:
- distances = numpy.take_along_axis(distances, ngb_order, axis=-1)
+ distances = numpy.take_along_axis(X, ngb_order, axis=-1)
return neighbors, distances
else:
return neighbors
if self.return_distance:
distances = numpy.partition(X, k - 1)[:, 0:k]
- # partition is not guaranteed to sort the lower half, although it often does
- distances.sort(axis=-1)
+ if self.params.get("sort_results"):
+ # partition is not guaranteed to sort the lower half, although it often does
+ distances.sort(axis=-1)
return distances
return None