summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/gudhi/point_cloud/knn.py78
-rwxr-xr-xsrc/python/test/test_dtm.py3
-rwxr-xr-xsrc/python/test/test_knn.py8
3 files changed, 73 insertions, 16 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 6642a3c2..f6870517 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -115,25 +115,71 @@ class KNearestNeighbors:
if metric == "precomputed":
# scikit-learn could handle that, but they insist on calling fit() with an unused square array, which is too unnatural.
- X = numpy.array(X)
if self.return_index:
- neighbors = numpy.argpartition(X, k - 1)[:, 0:k]
- if self.params.get("sort_results", True):
- X = numpy.take_along_axis(X, neighbors, axis=-1)
- ngb_order = numpy.argsort(X, axis=-1)
- neighbors = numpy.take_along_axis(neighbors, ngb_order, axis=-1)
+ n_jobs = self.params.get("n_jobs", 1)
+ # Supposedly numpy can be compiled with OpenMP and handle this, but nobody does that?!
+ if n_jobs == 1:
+ neighbors = numpy.argpartition(X, k - 1)[:, 0:k]
+ if self.params.get("sort_results", True):
+ X = numpy.take_along_axis(X, neighbors, axis=-1)
+ ngb_order = numpy.argsort(X, axis=-1)
+ neighbors = numpy.take_along_axis(neighbors, ngb_order, axis=-1)
+ else:
+ ngb_order = neighbors
+ if self.return_distance:
+ distances = numpy.take_along_axis(X, ngb_order, axis=-1)
+ return neighbors, distances
+ else:
+ return neighbors
else:
- ngb_order = neighbors
- if self.return_distance:
- distances = numpy.take_along_axis(X, ngb_order, axis=-1)
- return neighbors, distances
- else:
- return neighbors
+ from joblib import Parallel, delayed, effective_n_jobs
+ from sklearn.utils import gen_even_slices
+
+ slices = gen_even_slices(len(X), effective_n_jobs(-1))
+ parallel = Parallel(backend="threading", n_jobs=-1)
+ if self.params.get("sort_results", True):
+
+ def func(M):
+ neighbors = numpy.argpartition(M, k - 1)[:, 0:k]
+ Y = numpy.take_along_axis(M, neighbors, axis=-1)
+ ngb_order = numpy.argsort(Y, axis=-1)
+ return numpy.take_along_axis(neighbors, ngb_order, axis=-1)
+
+ else:
+
+ def func(M):
+ return numpy.argpartition(M, k - 1)[:, 0:k]
+
+ neighbors = numpy.concatenate(parallel(delayed(func)(X[s]) for s in slices))
+ if self.return_distance:
+ distances = numpy.take_along_axis(X, neighbors, axis=-1)
+ return neighbors, distances
+ else:
+ return neighbors
if self.return_distance:
- distances = numpy.partition(X, k - 1)[:, 0:k]
- if self.params.get("sort_results"):
- # partition is not guaranteed to sort the lower half, although it often does
- distances.sort(axis=-1)
+ n_jobs = self.params.get("n_jobs", 1)
+ if n_jobs == 1:
+ distances = numpy.partition(X, k - 1)[:, 0:k]
+ if self.params.get("sort_results"):
+ # partition is not guaranteed to sort the lower half, although it often does
+ distances.sort(axis=-1)
+ else:
+ from joblib import Parallel, delayed, effective_n_jobs
+ from sklearn.utils import gen_even_slices
+
+ if self.params.get("sort_results"):
+
+ def func(M):
+ # Not partitioning in place, because we should not modify the user's array?
+ r = numpy.partition(M, k - 1)[:, 0:k]
+ r.sort(axis=-1)
+ return r
+
+ else:
+ func = lambda M: numpy.partition(M, k - 1)[:, 0:k]
+ slices = gen_even_slices(len(X), effective_n_jobs(-1))
+ parallel = Parallel(backend="threading", n_jobs=-1)
+ distances = numpy.concatenate(parallel(delayed(func)(X[s]) for s in slices))
return distances
return None
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 37934fdb..bc0d3698 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -33,6 +33,9 @@ def test_dtm_compare_euclidean():
dtm = DistanceToMeasure(k, metric="precomputed")
r4 = dtm.fit_transform(d)
assert r4 == pytest.approx(r0)
+ dtm = DistanceToMeasure(k, metric="precomputed", n_jobs=2)
+ r4b = dtm.fit_transform(d)
+ assert r4b == pytest.approx(r0)
dtm = DistanceToMeasure(k, implementation="keops")
r5 = dtm.fit_transform(pts)
assert r5 == pytest.approx(r0)
diff --git a/src/python/test/test_knn.py b/src/python/test/test_knn.py
index 6aac2006..6269df54 100755
--- a/src/python/test/test_knn.py
+++ b/src/python/test/test_knn.py
@@ -52,6 +52,14 @@ def test_knn_explicit():
r = knn.fit_transform(dist)
assert np.array_equal(r[0], [[0, 1], [1, 0], [2, 0]])
assert np.array_equal(r[1], [[0, 3], [0, 1], [0, 1]])
+ # Second time in parallel
+ knn = KNearestNeighbors(2, metric="precomputed", return_index=True, return_distance=False, n_jobs=2)
+ r = knn.fit_transform(dist)
+ assert np.array_equal(r, [[0, 1], [1, 0], [2, 0]])
+ knn = KNearestNeighbors(2, metric="precomputed", return_index=True, return_distance=True, n_jobs=2)
+ r = knn.fit_transform(dist)
+ assert np.array_equal(r[0], [[0, 1], [1, 0], [2, 0]])
+ assert np.array_equal(r[1], [[0, 3], [0, 1], [0, 1]])
def test_knn_compare():