From cdba6045ddf1dd41e8addb7351d1c87a5506ba0f Mon Sep 17 00:00:00 2001 From: martinroyer <16647869+martinroyer@users.noreply.github.com> Date: Wed, 10 Jun 2020 10:20:13 +0200 Subject: Apply suggestions from code review --- src/python/gudhi/representations/vector_methods.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'src/python/gudhi/representations') diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index 77b2836f..667f963b 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -578,17 +578,17 @@ class ComplexPolynomial(BaseEstimator, TransformerMixin): def _lapl_contrast(measure, centers, inertias, eps=1e-8): """contrast function for vectorising `measure` in ATOL""" - return np.exp(-np.sqrt(pairwise.pairwise_distances(measure, Y=centers) / (inertias + eps))) + return np.exp(-pairwise.pairwise_distances(measure, Y=centers) / (inertias + eps)) def _gaus_contrast(measure, centers, inertias, eps=1e-8): """contrast function for vectorising `measure` in ATOL""" - return np.exp(-pairwise.pairwise_distances(measure, Y=centers) / (inertias + eps)) + return np.exp(-pairwise.pairwise_distances(measure, Y=centers)**2 / (inertias**2 + eps)) def _indicator_contrast(diags, centers, inertias, eps=1e-8): """contrast function for vectorising `measure` in ATOL""" pair_dist = pairwise.pairwise_distances(diags, Y=centers) flat_circ = (pair_dist < (inertias+eps)).astype(int) - robe_curve = np.positive((2-pair_dist/(inertias+eps))*((inertias+eps) < pair_dist).astype(int)) + robe_curve = np.clip(2-pair_dist/(inertias+eps), 0, 1) return flat_circ + robe_curve def _cloud_weighting(measure): @@ -638,7 +638,7 @@ class Atol(BaseEstimator, TransformerMixin): (default: constant function, i.e. the measure is seen as a point cloud by default). This will have no impact if weights are provided along with measures all the way: `fit` and `transform`. contrast (string): constant function for evaluating proximity of a measure with respect to centers - choose from {"gaussian", "laplacian", "indicator"} + choose from {"gaussian", "laplacian", "indicator"} (default: gaussian contrast function, see page 3 in the ATOL paper). """ self.quantiser = quantiser @@ -670,7 +670,7 @@ class Atol(BaseEstimator, TransformerMixin): """ if not hasattr(self.quantiser, 'fit'): raise TypeError("quantiser %s has no `fit` attribute." % (self.quantiser)) - if len(X) < self.quantiser.n_clusters: + if np.sum([measure.shape[0] for measure in X]) < self.quantiser.n_clusters: # in case there are not enough observations for fitting the quantiser, we add random points in [0, 1]^2 # @Martin: perhaps this behaviour is to be externalised and a warning should be raised instead random_points = np.random.rand(self.quantiser.n_clusters-len(X), X[0].shape[1]) @@ -681,7 +681,6 @@ class Atol(BaseEstimator, TransformerMixin): measures_concat = np.concatenate(X) self.quantiser.fit(X=measures_concat, sample_weight=sample_weight) self.centers = self.quantiser.cluster_centers_ - labels = np.argmin(pairwise.pairwise_distances(measures_concat, Y=self.centers), axis=1) dist_centers = pairwise.pairwise_distances(self.centers) np.fill_diagonal(dist_centers, np.inf) self.inertias = np.min(dist_centers, axis=0)/2 -- cgit v1.2.3