summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/knn.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-14 18:27:19 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-14 18:27:19 +0200
commit9518287cfa2a62948ede2e7d17d5c9f29092e0f4 (patch)
treeafa7c1eb5217528b92f4f51b93815108537a067a /src/python/gudhi/point_cloud/knn.py
parentb908205e85bbe29c8d18ad1f38e783a1327434d7 (diff)
Doc improvements
Diffstat (limited to 'src/python/gudhi/point_cloud/knn.py')
-rw-r--r--src/python/gudhi/point_cloud/knn.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 8b3cdb46..d7cf0b2a 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -38,9 +38,9 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
- enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.array or similar, this instructs
- the function to compute distances in a way that works with automatic differentiation.
- This is experimental and not supported for all implementations.
+ enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.ndarray or tensorflow.Tensor, this
+ instructs the function to compute distances in a way that works with automatic differentiation.
+ This is experimental and not supported for all metrics. Defaults to False.
kwargs: additional parameters are forwarded to the backends.
"""
self.k = k
@@ -124,6 +124,11 @@ class KNearestNeighbors:
"""
Args:
X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed".
+
+ Returns:
+ numpy.array: if return_index, an array of shape (len(X), k) with the indices (in the argument
+ of :func:`fit`) of the k nearest neighbors to the points of X. If return_distance, an array of the
+ same shape with the distances to those neighbors. If both, a tuple with the two arrays, in this order.
"""
if self.params.get("enable_autodiff", False):
# pykeops does not support autodiff for kmin yet, but when it does in the future,