From 2fb0d594060958804239fcdad5336832ea5133d0 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 25 May 2020 22:56:04 +0200 Subject: Add test --- src/python/gudhi/clustering/tomato.py | 18 ++++++++---------- src/python/test/test_tomato.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) create mode 100755 src/python/test/test_tomato.py diff --git a/src/python/gudhi/clustering/tomato.py b/src/python/gudhi/clustering/tomato.py index 55b64c1d..e3eaa300 100644 --- a/src/python/gudhi/clustering/tomato.py +++ b/src/python/gudhi/clustering/tomato.py @@ -18,24 +18,24 @@ class Tomato: Attributes ---------- n_clusters_: int - The number of clusters. Writing to it automatically adjusts labels_. + The number of clusters. Writing to it automatically adjusts `labels_`. merge_threshold_: float - minimum prominence of a cluster so it doesn't get merged. Writing to it automatically adjusts labels_. + minimum prominence of a cluster so it doesn't get merged. Writing to it automatically adjusts `labels_`. n_leaves_: int number of leaves (unstable clusters) in the hierarchical tree leaf_labels_: ndarray of shape (n_samples,) cluster labels for each point, at the very bottom of the hierarchy labels_: ndarray of shape (n_samples,) cluster labels for each point, after merging - diagram_: ndarray of shape (n_leaves_, 2) + diagram_: ndarray of shape (`n_leaves_`, 2) persistence diagram (only the finite points) max_weight_per_cc_: ndarray of shape (n_connected_components,) maximum of the density function on each connected component. This corresponds to the abscissa of infinite points in the diagram - children_: ndarray of shape (n_leaves_-n_connected_components, 2) - The children of each non-leaf node. Values less than n_leaves_ correspond to leaves of the tree. - A node i greater than or equal to n_leaves_ is a non-leaf node and has children children_[i - n_leaves_]. - Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node n_leaves_ + i + children_: ndarray of shape (`n_leaves_`-n_connected_components, 2) + The children of each non-leaf node. Values less than `n_leaves_` correspond to leaves of the tree. + A node i greater than or equal to `n_leaves_` is a non-leaf node and has children children_[i - `n_leaves_`]. + Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node `n_leaves_` + i weights_: ndarray of shape (n_samples,) weights of the points, as computed by the density estimator or provided by the user params_: dict @@ -53,8 +53,6 @@ class Tomato: **params ): """ - Each parameter has a corresponding attribute, like self.merge_threshold_, that can be changed later. - Args: graph_type (str): 'manual', 'knn' or 'radius'. density_type (str): 'manual', 'DTM', 'logDTM', 'KDE' or 'logKDE'. @@ -223,7 +221,7 @@ class Tomato: self.neighbors_[j].add(i) self.weights_ = weights - self.leaf_labels_, self.children_, self.diagram_, self.max_weight_per_cc_ = doit(list(self.neighbors_), weights) + self.leaf_labels_, self.children_, self.diagram_, self.max_weight_per_cc_ = doit(self.neighbors_, weights) self.n_leaves_ = len(self.max_weight_per_cc_) + len(self.children_) assert self.leaf_labels_.max() + 1 == len(self.max_weight_per_cc_) + len(self.children_) # TODO: deduplicate this code with the setters below diff --git a/src/python/test/test_tomato.py b/src/python/test/test_tomato.py new file mode 100755 index 00000000..0a33b86e --- /dev/null +++ b/src/python/test/test_tomato.py @@ -0,0 +1,27 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Marc Glisse + + Copyright (C) 2020 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +from gudhi.clustering.tomato import Tomato +import numpy as np +import pytest + + +def test_tomato_something(): + a = [(1, 2), (1.1, 1.9), (0.9, 1.8), (10, 0), (10.1, 0.05), (10.2, -0.1), (5.4, 0)] + t = Tomato(metric="euclidean", n_clusters=2, k=4, n_jobs=-1, eps=0.05) + assert np.array_equal(t.fit_predict(a), [1,1,1,0,0,0,0]) # or with swapped 0 and 1 + + t = Tomato(density_type='KDE', r=1, k=4) + t.fit(a) + assert np.array_equal(t.leaf_labels_, [1,1,1,0,0,0,0]) # or with swapped 0 and 1 + + t = Tomato(graph_type='radius', r=4.7, k=4) + t.fit(a) + assert t.max_weight_per_cc_.size == 2 -- cgit v1.2.3