diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-03-26 23:39:59 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-03-26 23:39:59 +0100 |
commit | af35ea5b4ce631ae826f1db1940798f254aba658 (patch) | |
tree | 0edb91b4cfbdb1ef9ee393d15fadfe4ffc31cf7e | |
parent | 7120b186471828a9570fdeef37900bd8b98d0d31 (diff) |
clean-up use of "implementation"
-rw-r--r-- | src/python/gudhi/point_cloud/knn.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py index 943d4e9f..a4ea3acd 100644 --- a/src/python/gudhi/point_cloud/knn.py +++ b/src/python/gudhi/point_cloud/knn.py @@ -72,12 +72,12 @@ class KNN: X (numpy.array): coordinates for reference points """ self.ref_points = X - if self.params.get("implementation") == "ckdtree": + if self.params["implementation"] == "ckdtree": # sklearn could handle this, but it is much slower from scipy.spatial import cKDTree self.kdtree = cKDTree(X) - if self.params.get("implementation") == "sklearn" and self.metric != "precomputed": + if self.params["implementation"] == "sklearn" and self.metric != "precomputed": # FIXME: sklearn badly handles "precomputed" from sklearn.neighbors import NearestNeighbors @@ -85,7 +85,7 @@ class KNN: self.nn = NearestNeighbors(self.k, metric=self.metric, **nargs) self.nn.fit(X) - if self.params.get("implementation") == "hnsw": + if self.params["implementation"] == "hnsw": import hnswlib self.graph = hnswlib.Index("l2", len(X[0])) # Actually returns squared distances self.graph.init_index(len(X), **{k:v for k,v in self.params.items() if k in {"ef_construction", "M", "random_seed"}}) @@ -125,7 +125,7 @@ class KNN: return distances return None - if self.params.get("implementation") == "hnsw": + if self.params["implementation"] == "hnsw": ef = self.params.get("ef") if ef is not None: self.graph.set_ef(ef) @@ -141,7 +141,7 @@ class KNN: return numpy.sqrt(distances) return None - if self.params.get("implementation") == "keops": + if self.params["implementation"] == "keops": import torch from pykeops.torch import LazyTensor @@ -178,7 +178,7 @@ class KNN: return None # FIXME: convert everything back to numpy arrays or not? - if hasattr(self, "kdtree"): + if self.params["implementation"] == "ckdtree": qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}} distances, neighbors = self.kdtree.query(X, k=self.k, **qargs) if self.return_index: @@ -190,6 +190,7 @@ class KNN: return distances return None + assert self.params["implementation"] == "sklearn" if self.return_distance: distances, neighbors = self.nn.kneighbors(X, return_distance=True) if self.return_index: |