From f3b7d742580b3c93f8d1d70952a7809f6f52ca80 Mon Sep 17 00:00:00 2001
From: Marc Glisse
Date: Fri, 28 Feb 2020 23:23:01 +0100
Subject: Use official formula for DTM

---
 src/python/gudhi/clustering/tomato.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

(limited to 'src/python/gudhi/clustering')

diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 19d6600f..5257e487 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -77,8 +77,8 @@ class Tomato:
     def fit(self, X, y=None, weights=None):
         """
         Args:
-            X (?): points or distance_matrix or list of neighbors
-            weights (ndarray of shape (n_samples)): if density_type == 'manual', a density estimate at each point
+            X ((n,d)-array of float|(n,n)-array of float|Iterable[Iterable[int]]): coordinates of the points, or distance_matrix (full, not just a triangle), or list of neighbors for each point (points are represented by their index, starting from 0)
+            weights (ndarray of shape (n_samples)): if density_type is 'manual', a density estimate at each point
         """
         # TODO: First detect if this is a new call with the same data (only threshold changed?)
         # TODO: less code duplication (subroutines?), less spaghetti, but don't compute neighbors twice if not needed. Clear error message for missing or contradictory parameters.
@@ -148,7 +148,6 @@ class Tomato:
                 if qp != 1:
                     dd = dd ** qp
                 weights = dd.sum(-1)
-                # **1/q is a waste of time, whether we take another **-.25 or a log

                 # Back to the CPU. Not sure this is necessary, or the right way to do it.
                 weights = numpy.array(weights)
@@ -166,10 +165,13 @@ class Tomato:
                 # weights = numpy.linalg.norm(dd, axis=1, ord=q)
                 weights = (dd ** q).sum(-1)

-        # TODO: check the formula in Fred's paper
         if self.density_type_ == "DTM":
-            weights = weights ** (-0.25 / q)
+            # We ignore constant factors, which don't matter for
+            # clustering, although they do change thresholds
+            dim = len(self.points_[0])
+            weights = weights ** (-dim / q)
         else:
+            # We ignore exponents, which become constant factors with log
             weights = -numpy.log(weights)

         if self.input_type_ == "points" and self.graph_type_ == "knn" and self.density_type_ not in {"DTM", "logDTM"}:
@@ -245,9 +247,9 @@ class Tomato:
            # weights = numpy.linalg.norm(dd, axis=1, ord=q)
            weights = (dd ** q).sum(-1)

-           # TODO: check the formula in Fred's paper
            if self.density_type_ == "DTM":
-               weights = weights ** (-0.25 / q)
+               dim = len(self.points_[0])
+               weights = weights ** (-dim / q)
            else:
                weights = -numpy.log(weights)

@@ -258,10 +260,10 @@ class Tomato:
            if not k:
                k = self.params_["k"]
            q = self.params_.get("p_DTM", 2)
-           weights = (numpy.partition(X) ** q, k - 1).sum(-1)
-           # TODO: check the formula in Fred's paper
+           weights = (numpy.partition(X, k - 1)[:,0:k] ** q).sum(-1)
            if self.density_type_ == "DTM":
-               weights = weights ** (-0.25 / q)
+               dim = len(self.points_[0])
+               weights = weights ** (-dim / q)
            else:
                weights = -numpy.log(weights)

--
cgit v1.2.3
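
For reference, the density that the patched code computes is, up to the constant factors it deliberately drops, density(x) ~ DTM_q(x)**(-dim), where DTM_q(x)**q is the sum of the q-th powers of the distances from x to its k nearest neighbors. The following is a minimal standalone sketch of that computation, not GUDHI API: it assumes a full (n, n) distance matrix as input, and the helper name dtm_density is hypothetical (it is not part of tomato.py).

import numpy

def dtm_density(dist_matrix, k=10, q=2, dim=2, log=False):
    """Unnormalized DTM-based density estimate at each point.

    dist_matrix: full (n, n) array of pairwise distances.
    dim: ambient dimension of the original points.
    """
    # k smallest distances per row (the point itself, at distance 0, counts
    # as one of the k neighbors), mirroring numpy.partition(X, k - 1)[:,0:k]
    # in the patch.
    dd = numpy.partition(dist_matrix, k - 1)[:, 0:k]
    # Sum of the q-th powers of those distances, proportional to DTM**q.
    weights = (dd ** q).sum(-1)
    if log:
        # logDTM variant: exponents become constant factors under the log,
        # so they are dropped, as in the patched code.
        return -numpy.log(weights)
    # DTM variant: density proportional to DTM**(-dim); constant factors ignored.
    return weights ** (-dim / q)

As in the patch, only relative values matter for clustering, so the result is defined up to a constant factor; absolute thresholds on the density would need that factor to be taken into account.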