summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/dtm.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2020-05-20 08:42:23 +0200
committerGard Spreemann <gspr@nonempty.org>2020-05-20 08:42:23 +0200
commit9b3079646ee3f6a494b83e864b3e10b8a93597d0 (patch)
tree63ecae8cf0d09b72907805e68f19765c7dd9694a /src/python/gudhi/point_cloud/dtm.py
parent81816dae256a9f3c0653b1d21443c3c32da7a974 (diff)
parent97e889f34e929f3c2306803b6c37b57926bd1245 (diff)
Merge tag 'tags/gudhi-release-3.2.0' into dfsg/latest
Diffstat (limited to 'src/python/gudhi/point_cloud/dtm.py')
-rw-r--r--src/python/gudhi/point_cloud/dtm.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
new file mode 100644
index 00000000..13e16d24
--- /dev/null
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -0,0 +1,70 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Marc Glisse
+#
+# Copyright (C) 2020 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+from .knn import KNearestNeighbors
+
+__author__ = "Marc Glisse"
+__copyright__ = "Copyright (C) 2020 Inria"
+__license__ = "MIT"
+
+
+class DistanceToMeasure:
+ """
+ Class to compute the distance to the empirical measure defined by a point set, as introduced in :cite:`dtm`.
+ """
+
+ def __init__(self, k, q=2, **kwargs):
+ """
+ Args:
+ k (int): number of neighbors (possibly including the point itself).
+ q (float): order used to compute the distance to measure. Defaults to 2.
+ kwargs: same parameters as :class:`~gudhi.point_cloud.knn.KNearestNeighbors`, except that
+ metric="neighbors" means that :func:`transform` expects an array with the distances
+ to the k nearest neighbors.
+ """
+ self.k = k
+ self.q = q
+ self.params = kwargs
+
+ def fit_transform(self, X, y=None):
+ return self.fit(X).transform(X)
+
+ def fit(self, X, y=None):
+ """
+ Args:
+ X (numpy.array): coordinates for mass points.
+ """
+ if self.params.setdefault("metric", "euclidean") != "neighbors":
+ self.knn = KNearestNeighbors(
+ self.k, return_index=False, return_distance=True, sort_results=False, **self.params
+ )
+ self.knn.fit(X)
+ return self
+
+ def transform(self, X):
+ """
+ Args:
+ X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed",
+ or distances to the k nearest neighbors if metric is "neighbors" (if the array has more
+ than k columns, the remaining ones are ignored).
+
+ Returns:
+ numpy.array: a 1-d array with, for each point of X, its distance to the measure defined
+ by the argument of :func:`fit`.
+ """
+ if self.params["metric"] == "neighbors":
+ distances = X[:, : self.k]
+ else:
+ distances = self.knn.transform(X)
+ distances = distances ** self.q
+ dtm = distances.sum(-1) / self.k
+ dtm = dtm ** (1.0 / self.q)
+ # We compute too many powers, 1/p in knn then q in dtm, 1/q in dtm then q or some log in the caller.
+ # Add option to skip the final root?
+ return dtm