summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-26 23:39:59 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-26 23:39:59 +0100
commitaf35ea5b4ce631ae826f1db1940798f254aba658 (patch)
tree0edb91b4cfbdb1ef9ee393d15fadfe4ffc31cf7e /src/python/gudhi/point_cloud
parent7120b186471828a9570fdeef37900bd8b98d0d31 (diff)
clean-up use of "implementation"
Diffstat (limited to 'src/python/gudhi/point_cloud')
-rw-r--r--src/python/gudhi/point_cloud/knn.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 943d4e9f..a4ea3acd 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -72,12 +72,12 @@ class KNN:
X (numpy.array): coordinates for reference points
"""
self.ref_points = X
- if self.params.get("implementation") == "ckdtree":
+ if self.params["implementation"] == "ckdtree":
# sklearn could handle this, but it is much slower
from scipy.spatial import cKDTree
self.kdtree = cKDTree(X)
- if self.params.get("implementation") == "sklearn" and self.metric != "precomputed":
+ if self.params["implementation"] == "sklearn" and self.metric != "precomputed":
# FIXME: sklearn badly handles "precomputed"
from sklearn.neighbors import NearestNeighbors
@@ -85,7 +85,7 @@ class KNN:
self.nn = NearestNeighbors(self.k, metric=self.metric, **nargs)
self.nn.fit(X)
- if self.params.get("implementation") == "hnsw":
+ if self.params["implementation"] == "hnsw":
import hnswlib
self.graph = hnswlib.Index("l2", len(X[0])) # Actually returns squared distances
self.graph.init_index(len(X), **{k:v for k,v in self.params.items() if k in {"ef_construction", "M", "random_seed"}})
@@ -125,7 +125,7 @@ class KNN:
return distances
return None
- if self.params.get("implementation") == "hnsw":
+ if self.params["implementation"] == "hnsw":
ef = self.params.get("ef")
if ef is not None:
self.graph.set_ef(ef)
@@ -141,7 +141,7 @@ class KNN:
return numpy.sqrt(distances)
return None
- if self.params.get("implementation") == "keops":
+ if self.params["implementation"] == "keops":
import torch
from pykeops.torch import LazyTensor
@@ -178,7 +178,7 @@ class KNN:
return None
# FIXME: convert everything back to numpy arrays or not?
- if hasattr(self, "kdtree"):
+ if self.params["implementation"] == "ckdtree":
qargs = {key: val for key, val in self.params.items() if key in {"p", "eps", "n_jobs"}}
distances, neighbors = self.kdtree.query(X, k=self.k, **qargs)
if self.return_index:
@@ -190,6 +190,7 @@ class KNN:
return distances
return None
+ assert self.params["implementation"] == "sklearn"
if self.return_distance:
distances, neighbors = self.nn.kneighbors(X, return_distance=True)
if self.return_index: