summaryrefslogtreecommitdiff
path: root/src/python/gudhi/point_cloud/dtm.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-03-26 22:10:26 +0100
committerMarc Glisse <marc.glisse@inria.fr>2020-03-26 22:10:26 +0100
commitc8c942c43643131a7ef9899826a7095e497150fe (patch)
treeaec910d4ddebe7dd942d62ca5dfb09b814f06ede /src/python/gudhi/point_cloud/dtm.py
parentf9a0e1ec856f26c08e7b6493df076bb70d775551 (diff)
cmake
Diffstat (limited to 'src/python/gudhi/point_cloud/dtm.py')
-rw-r--r--src/python/gudhi/point_cloud/dtm.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/src/python/gudhi/point_cloud/dtm.py b/src/python/gudhi/point_cloud/dtm.py
new file mode 100644
index 00000000..08f9ea60
--- /dev/null
+++ b/src/python/gudhi/point_cloud/dtm.py
@@ -0,0 +1,40 @@
+from .knn import KNN
+
+
+class DTM:
+ def __init__(self, k, q=2, **kwargs):
+ """
+ Args:
+ q (float): order used to compute the distance to measure. Defaults to the dimension, or 2 if input_type is 'distance_matrix'.
+ kwargs: Same parameters as KNN, except that metric="neighbors" means that transform() expects an array with the distances to the k nearest neighbors.
+ """
+ self.k = k
+ self.q = q
+ self.params = kwargs
+
+ def fit_transform(self, X, y=None):
+ return self.fit(X).transform(X)
+
+ def fit(self, X, y=None):
+ """
+ Args:
+ X (numpy.array): coordinates for mass points
+ """
+ if self.params.setdefault("metric", "euclidean") != "neighbors":
+ self.knn = KNN(self.k, return_index=False, return_distance=True, **self.params)
+ self.knn.fit(X)
+ return self
+
+ def transform(self, X):
+ """
+ Args:
+ X (numpy.array): coordinates for query points, or distance matrix if metric is "precomputed", or distances to the k nearest neighbors if metric is "neighbors" (if the array has more than k columns, the remaining ones are ignored).
+ """
+ if self.params["metric"] == "neighbors":
+ distances = X[:, : self.k]
+ else:
+ distances = self.knn.transform(X)
+ distances = distances ** self.q
+ dtm = distances.sum(-1) / self.k
+ dtm = dtm ** (1.0 / self.q)
+ return dtm