path: root/src/python/gudhi/clustering/tomato.py
author    Marc Glisse <marc.glisse@inria.fr>    2020-02-28 23:23:01 +0100
committer Marc Glisse <marc.glisse@inria.fr>    2020-02-28 23:23:01 +0100
commit    f3b7d742580b3c93f8d1d70952a7809f6f52ca80 (patch)
tree      cb1a7ff2c6b44dbe5216617ae7a9ecd273b47954 /src/python/gudhi/clustering/tomato.py
parent    e5e0f9a9e96389eadc9e9c4bc493b88abcb6f89a (diff)
Use official formula for DTM
Diffstat (limited to 'src/python/gudhi/clustering/tomato.py')
-rw-r--r--  src/python/gudhi/clustering/tomato.py  22
1 file changed, 12 insertions, 10 deletions
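
The change replaces the ad-hoc exponent -0.25/q with -dim/q, i.e. the distance-to-measure density estimator: up to constant factors that do not affect the clustering, the density at a point is (sum of the q-th powers of its k nearest-neighbour distances) ** (-d/q), where d is the ambient dimension. A minimal standalone sketch of that estimator (the function name, the use of scipy's cKDTree and the default values are illustrative, not part of tomato.py):

import numpy
from scipy.spatial import cKDTree

def dtm_density(X, k=10, q=2):
    # Distances from each point to its k nearest neighbours (the point itself included).
    dd, _ = cKDTree(X).query(X, k=k)
    dim = X.shape[1]
    # Up to constant factors: density ~ (sum of q-th powers of kNN distances) ** (-dim / q).
    return (dd ** q).sum(-1) ** (-dim / q)
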
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 19d6600f..5257e487 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -77,8 +77,8 @@ class Tomato:
def fit(self, X, y=None, weights=None):
"""
Args:
- X (?): points or distance_matrix or list of neighbors
- weights (ndarray of shape (n_samples)): if density_type == 'manual', a density estimate at each point
+ X ((n,d)-array of float|(n,n)-array of float|Iterable[Iterable[int]]): coordinates of the points, or distance_matrix (full, not just a triangle), or list of neighbors for each point (points are represented by their index, starting from 0)
+ weights (ndarray of shape (n_samples)): if density_type is 'manual', a density estimate at each point
"""
# TODO: First detect if this is a new call with the same data (only threshold changed?)
# TODO: less code duplication (subroutines?), less spaghetti, but don't compute neighbors twice if not needed. Clear error message for missing or contradictory parameters.
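
As the updated docstring spells out, X can be an (n, d) array of point coordinates, a full (n, n) distance matrix, or a list of neighbour lists with points indexed from 0. A minimal usage sketch with point coordinates (the constructor keywords density_type and k, and the labels_ attribute, are assumptions about the rest of the class, not shown in this diff):

import numpy
from gudhi.clustering.tomato import Tomato

X = numpy.random.default_rng(0).random((100, 2))   # (n, d) array of points
t = Tomato(density_type="DTM", k=10)               # density estimated from the k nearest neighbours
t.fit(X)
print(t.labels_)                                   # one cluster label per input point
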
@@ -148,7 +148,6 @@ class Tomato:
if qp != 1:
dd = dd ** qp
weights = dd.sum(-1)
- # **1/q is a waste of time, whether we take another **-.25 or a log
# Back to the CPU. Not sure this is necessary, or the right way to do it.
weights = numpy.array(weights)
@@ -166,10 +165,13 @@ class Tomato:
# weights = numpy.linalg.norm(dd, axis=1, ord=q)
weights = (dd ** q).sum(-1)
- # TODO: check the formula in Fred's paper
if self.density_type_ == "DTM":
- weights = weights ** (-0.25 / q)
+ # We ignore constant factors, which don't matter for
+ # clustering, although they do change thresholds
+ dim = len(self.points_[0])
+ weights = weights ** (-dim / q)
else:
+ # We ignore exponents, which become constant factors with log
weights = -numpy.log(weights)
if self.input_type_ == "points" and self.graph_type_ == "knn" and self.density_type_ not in {"DTM", "logDTM"}:
@@ -245,9 +247,9 @@ class Tomato:
# weights = numpy.linalg.norm(dd, axis=1, ord=q)
weights = (dd ** q).sum(-1)
- # TODO: check the formula in Fred's paper
if self.density_type_ == "DTM":
- weights = weights ** (-0.25 / q)
+ dim = len(self.points_[0])
+ weights = weights ** (-dim / q)
else:
weights = -numpy.log(weights)
@@ -258,10 +260,10 @@ class Tomato:
if not k:
k = self.params_["k"]
q = self.params_.get("p_DTM", 2)
- weights = (numpy.partition(X) ** q, k - 1).sum(-1)
- # TODO: check the formula in Fred's paper
+ weights = (numpy.partition(X, k - 1)[:,0:k] ** q).sum(-1)
if self.density_type_ == "DTM":
- weights = weights ** (-0.25 / q)
+ dim = len(self.points_[0])
+ weights = weights ** (-dim / q)
else:
weights = -numpy.log(weights)
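
The last hunk also repairs the numpy.partition call: numpy.partition(X, k - 1) moves the k smallest entries of each row into the first k columns, so slicing [:,0:k] picks the distances to the k nearest neighbours when X is a full distance matrix. A small self-contained check of that selection (the data and sizes are arbitrary):

import numpy

rng = numpy.random.default_rng(2)
pts = rng.random((6, 2))
D = numpy.linalg.norm(pts[:, None] - pts[None, :], axis=-1)   # full (n, n) distance matrix
k, q = 3, 2
knn = numpy.partition(D, k - 1)[:, 0:k]   # k smallest distances per row (unordered within the slice)
weights = (knn ** q).sum(-1)              # same quantity as in the patched code
print(weights.shape)                      # (6,)
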