summaryrefslogtreecommitdiff
path: root/src/python/gudhi/clustering
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-28 21:04:01 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-28 21:04:01 +0200
commit4c9c01ab794901210859e299c3528be8f26f5f27 (patch)
tree934e20cb388c7fd72ccd8b6671a603b387db5d2d /src/python/gudhi/clustering
parentb84b3f006805eb69e03983301a550ddcb8050769 (diff)
Unbreak KDE
Diffstat (limited to 'src/python/gudhi/clustering')
-rw-r--r--src/python/gudhi/clustering/tomato.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py
index 07006d7c..88a1a34d 100644
--- a/src/python/gudhi/clustering/tomato.py
+++ b/src/python/gudhi/clustering/tomato.py
@@ -229,7 +229,6 @@ class Tomato:
# 'float64' is slow except on super expensive GPUs. Allow it with some param?
XX = torch.tensor(self.points_, dtype=torch.float32)
if p == numpy.inf:
- assert False # Not supported???
dd = (LazyTensor(XX[:, None, :]) - LazyTensor(XX[None, :, :])).abs().max(-1).Kmin(k_DTM, dim=1)
elif p == 2: # Any even integer?
dd = ((LazyTensor(XX[:, None, :]) - LazyTensor(XX[None, :, :])) ** p).sum(-1).Kmin(k_DTM, dim=1)
@@ -284,10 +283,10 @@ class Tomato:
if self.density_type_ in {"KDE", "logKDE"}:
# FIXME: replace most assert with raise ValueError("blabla")
- assert input_type == "points"
+ # assert input_type == "points"
kde_params = self.params_.get("kde_params", dict())
from sklearn.neighbors import KernelDensity
- weights = KernelDensity(**kde_params).fit(X).score_samples(X)
+ weights = KernelDensity(**kde_params).fit(self.points_).score_samples(self.points_)
if self.density_type_ == "KDE":
weights = numpy.exp(weights)