From 76529cae58f8a2736a1730fd81a9e12c3f4c7e19 Mon Sep 17 00:00:00 2001 From: martinroyer <16647869+martinroyer@users.noreply.github.com> Date: Thu, 11 Jun 2020 16:47:16 +0200 Subject: Apply suggestions from code review #456 (thank you Marc!) --- src/python/gudhi/representations/vector_methods.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) (limited to 'src/python/gudhi/representations/vector_methods.py') diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index ede1087f..49c05c51 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -586,10 +586,8 @@ def _gaus_contrast(measure, centers, inertias): def _indicator_contrast(diags, centers, inertias): """contrast function for vectorising `measure` in ATOL""" - pair_dist = pairwise.pairwise_distances(diags, Y=centers) - flat_circ = (pair_dist < inertias).astype(int) - robe_curve = np.clip(2-pair_dist/inertias, 0, 1) - return flat_circ + robe_curve + robe_curve = np.clip(2-pairwise.pairwise_distances(diags, Y=centers)/inertias, 0, 1) + return robe_curve def _cloud_weighting(measure): """automatic uniform weighting with mass 1 for `measure` in ATOL""" @@ -603,7 +601,7 @@ class Atol(BaseEstimator, TransformerMixin): """ This class allows to vectorise measures (e.g. point clouds, persistence diagrams, etc) after a quantisation step. - ATOL paper: https://arxiv.org/abs/1909.13472 + ATOL paper: :cite:`royer2019atol` Example -------- @@ -632,9 +630,9 @@ class Atol(BaseEstimator, TransformerMixin): Parameters: quantiser (Object): Object with `fit` (sklearn API consistent) and `cluster_centers` and `n_clusters` - attributes. This object will be fitted by the function `fit`. + attributes, e.g. sklearn.cluster.KMeans. It will be fitted when the Atol object function `fit` is called. weighting_method (string): constant generic function for weighting the measure points - choose from {"cloud", "iidproba"} + choose from {"cloud", "iidproba"} (default: constant function, i.e. the measure is seen as a point cloud by default). This will have no impact if weights are provided along with measures all the way: `fit` and `transform`. contrast (string): constant function for evaluating proximity of a measure with respect to centers @@ -647,8 +645,6 @@ class Atol(BaseEstimator, TransformerMixin): "laplacian": _lapl_contrast, "indicator": _indicator_contrast, }.get(contrast, _gaus_contrast) - self.centers = np.ones(shape=(self.quantiser.n_clusters, 2))*np.inf - self.inertias = np.full(self.quantiser.n_clusters, np.nan) self.weighting_method = { "cloud" : _cloud_weighting, "iidproba": _iidproba_weighting, @@ -670,11 +666,6 @@ class Atol(BaseEstimator, TransformerMixin): """ if not hasattr(self.quantiser, 'fit'): raise TypeError("quantiser %s has no `fit` attribute." % (self.quantiser)) - if np.sum([measure.shape[0] for measure in X]) < self.quantiser.n_clusters: - # in case there are not enough observations for fitting the quantiser, we add random points in [0, 1]^2 - # @Martin: perhaps this behaviour is to be externalised and a warning should be raised instead - random_points = np.random.rand(self.quantiser.n_clusters-len(X), X[0].shape[1]) - X.append(random_points) if sample_weight is None: sample_weight = np.concatenate([self.weighting_method(measure) for measure in X]) -- cgit v1.2.3