diff options
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r-- | src/python/gudhi/point_cloud/knn.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py index 07553d6d..86008bc3 100644 --- a/src/python/gudhi/point_cloud/knn.py +++ b/src/python/gudhi/point_cloud/knn.py @@ -19,6 +19,10 @@ __license__ = "MIT" class KNearestNeighbors: """ Class wrapping several implementations for computing the k nearest neighbors in a point set. + + :Requires: `PyKeOps <installation.html#pykeops>`_, `SciPy <installation.html#scipy>`_, + `Scikit-learn <installation.html#scikit-learn>`_, and/or `Hnswlib <installation.html#hnswlib>`_ + in function of the selected `implementation`. """ def __init__(self, k, return_index=True, return_distance=False, metric="euclidean", **kwargs): @@ -200,8 +204,8 @@ class KNearestNeighbors: from joblib import Parallel, delayed, effective_n_jobs from sklearn.utils import gen_even_slices - slices = gen_even_slices(len(X), effective_n_jobs(-1)) - parallel = Parallel(backend="threading", n_jobs=-1) + slices = gen_even_slices(len(X), effective_n_jobs(n_jobs)) + parallel = Parallel(prefer="threads", n_jobs=n_jobs) if self.params.get("sort_results", True): def func(M): @@ -242,8 +246,8 @@ class KNearestNeighbors: else: func = lambda M: numpy.partition(M, k - 1)[:, 0:k] - slices = gen_even_slices(len(X), effective_n_jobs(-1)) - parallel = Parallel(backend="threading", n_jobs=-1) + slices = gen_even_slices(len(X), effective_n_jobs(n_jobs)) + parallel = Parallel(prefer="threads", n_jobs=n_jobs) distances = numpy.concatenate(parallel(delayed(func)(X[s]) for s in slices)) return distances return None |