summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/doc/point_cloud.rst17
-rw-r--r--src/python/gudhi/point_cloud/dtm.py6
-rw-r--r--src/python/gudhi/point_cloud/knn.py31
-rwxr-xr-xsrc/python/test/test_dtm.py2
4 files changed, 39 insertions, 17 deletions
diff --git a/src/python/doc/point_cloud.rst b/src/python/doc/point_cloud.rst
index c0d4b303..351b0786 100644
--- a/src/python/doc/point_cloud.rst
+++ b/src/python/doc/point_cloud.rst
@@ -21,10 +21,23 @@ Subsampling
:special-members:
:show-inheritance:
-TimeDelayEmbedding
-------------------
+Time Delay Embedding
+--------------------
.. autoclass:: gudhi.point_cloud.timedelay.TimeDelayEmbedding
:members:
:special-members: __call__
+Nearest neighbors
+-----------------
+
+.. automodule:: gudhi.point_cloud.knn
+ :members:
+ :special-members: __init__
+
+Distance to measure
+-------------------
+
+.. automodule:: gudhi.point_cloud.dtm
+ :members:
+ :special-members: __init__
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
index 541b74a6..e4096c5e 100644
--- a/src/python/gudhi/point_cloud/dtm.py
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -11,11 +11,15 @@ from .knn import KNN
class DTM:
+ """
+ Class to compute the distance to the empirical measure defined by a point set.
+ """
+
def __init__(self, k, q=2, **kwargs):
"""
Args:
q (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if input_type is 'distance_matrix'.
- kwargs: Same parameters as KNN, except that metric="neighbors" means that transform() expects an array with the distances to the k nearest neighbors.
+ kwargs: Same parameters as :class:`~gudhi.point_cloud.knn.KNN`, except that metric="neighbors" means that :func:`transform` expects an array with the distances to the k nearest neighbors.
"""
self.k = k
self.q = q
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index a4ea3acd..02448530 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -11,6 +11,10 @@ import numpy
class KNN:
+ """
+ Class wrapping several implementations for computing the k nearest neighbors in a point set.
+ """
+
def __init__(self, k, return_index=True, return_distance=False, metric="euclidean", **kwargs):
"""
Args:
@@ -19,22 +23,17 @@ class KNN:
return_distance (bool): if True, return the distance to each neighbor.
implementation (str): Choice of the library that does the real work.
- * 'keops' for a brute-force, CUDA implementation through pykeops. Useful when the dimension becomes
- large (10+) but the number of points remains low (less than a million).
- Only "minkowski" and its aliases are supported.
+ * 'keops' for a brute-force, CUDA implementation through pykeops. Useful when the dimension becomes large (10+) but the number of points remains low (less than a million). Only "minkowski" and its aliases are supported.
* 'ckdtree' for scipy's cKDTree. Only "minkowski" and its aliases are supported.
- * 'sklearn' for scikit-learn's NearestNeighbors.
- Note that this provides in particular an option algorithm="brute".
- * 'hnsw' for hnswlib.Index. It is very fast but does not provide guarantees.
- Only supports "euclidean" for now.
+ * 'sklearn' for scikit-learn's NearestNeighbors. Note that this provides in particular an option algorithm="brute".
+ * 'hnsw' for hnswlib.Index. It can be very fast but does not provide guarantees. Only supports "euclidean" for now.
* None will try to select a sensible one (scipy if possible, scikit-learn otherwise).
metric (str): see `sklearn.neighbors.NearestNeighbors`.
eps (float): relative error when computing nearest neighbors with the cKDTree.
p (float): norm L^p on input points (including numpy.inf) if metric is "minkowski". Defaults to 2.
n_jobs (int): Number of jobs to schedule for parallel processing of nearest neighbors on the CPU.
If -1 is given all processors are used. Default: 1.
-
- Additional parameters are forwarded to the backends.
+ kwargs: additional parameters are forwarded to the backends.
"""
self.k = k
self.return_index = return_index
@@ -75,20 +74,26 @@ class KNN:
if self.params["implementation"] == "ckdtree":
# sklearn could handle this, but it is much slower
from scipy.spatial import cKDTree
+
self.kdtree = cKDTree(X)
if self.params["implementation"] == "sklearn" and self.metric != "precomputed":
# FIXME: sklearn badly handles "precomputed"
from sklearn.neighbors import NearestNeighbors
- nargs = {k: v for k, v in self.params.items() if k in {"p", "n_jobs", "metric_params", "algorithm", "leaf_size"}}
+ nargs = {
+ k: v for k, v in self.params.items() if k in {"p", "n_jobs", "metric_params", "algorithm", "leaf_size"}
+ }
self.nn = NearestNeighbors(self.k, metric=self.metric, **nargs)
self.nn.fit(X)
if self.params["implementation"] == "hnsw":
import hnswlib
- self.graph = hnswlib.Index("l2", len(X[0])) # Actually returns squared distances
- self.graph.init_index(len(X), **{k:v for k,v in self.params.items() if k in {"ef_construction", "M", "random_seed"}})
+
+ self.graph = hnswlib.Index("l2", len(X[0])) # Actually returns squared distances
+ self.graph.init_index(
+ len(X), **{k: v for k, v in self.params.items() if k in {"ef_construction", "M", "random_seed"}}
+ )
n = self.params.get("num_threads")
if n is None:
n = self.params.get("n_jobs", 1)
@@ -154,7 +159,7 @@ class KNN:
p = self.params["p"]
if p == numpy.inf:
- # Requires a version of pykeops strictly more recent than 1.3
+ # Requires pykeops 1.4 or later
mat = (LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])).abs().max(-1)
elif p == 2: # Any even integer?
mat = ((LazyTensor(XX[:, None, :]) - LazyTensor(YY[None, :, :])) ** p).sum(-1)
diff --git a/src/python/test/test_dtm.py b/src/python/test/test_dtm.py
index 841f8c3c..93b13e1a 100755
--- a/src/python/test/test_dtm.py
+++ b/src/python/test/test_dtm.py
@@ -47,4 +47,4 @@ def test_dtm_precomputed():
dist = numpy.array([[2.0, 2], [0, 1], [3, 4]])
dtm = DTM(2, q=2, metric="neighbors")
r = dtm.fit_transform(dist)
- assert r == pytest.approx([2.0, .707, 3.5355], rel=.01)
+ assert r == pytest.approx([2.0, 0.707, 3.5355], rel=0.01)