From 9b4de0e29a01552b4bb3f47fe0d3f01f5601c000 Mon Sep 17 00:00:00 2001 From: martinroyer <16647869+martinroyer@users.noreply.github.com> Date: Tue, 9 Jun 2020 08:42:30 +0200 Subject: Apply suggestions from code review --- src/python/gudhi/representations/vector_methods.py | 45 ++++++++++++++++------ 1 file changed, 33 insertions(+), 12 deletions(-) (limited to 'src/python/gudhi/representations') diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index df66ffc3..a09b9356 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -577,26 +577,26 @@ class ComplexPolynomial(BaseEstimator, TransformerMixin): return self.fit_transform([diag])[0,:] def _lapl_contrast(measure, centers, inertias, eps=1e-8): - """contrast function for vectorising `measure` in ATOL""" + """contrast function for vectorising `measure` in ATOL""" return np.exp(-np.sqrt(pairwise.pairwise_distances(measure, Y=centers) / (inertias + eps))) def _gaus_contrast(measure, centers, inertias, eps=1e-8): - """contrast function for vectorising `measure` in ATOL""" + """contrast function for vectorising `measure` in ATOL""" return np.exp(-pairwise.pairwise_distances(measure, Y=centers) / (inertias + eps)) def _indicator_contrast(diags, centers, inertias, eps=1e-8): - """contrast function for vectorising `measure` in ATOL""" + """contrast function for vectorising `measure` in ATOL""" pair_dist = pairwise.pairwise_distances(diags, Y=centers) flat_circ = (pair_dist < (inertias+eps)).astype(int) robe_curve = np.positive((2-pair_dist/(inertias+eps))*((inertias+eps) < pair_dist).astype(int)) return flat_circ + robe_curve def _cloud_weighting(measure): - """automatic uniform weighting with mass 1 for `measure` in ATOL""" + """automatic uniform weighting with mass 1 for `measure` in ATOL""" return np.ones(shape=measure.shape[0]) def _iidproba_weighting(measure): - """automatic uniform weighting with mass 1/N for `measure` in ATOL""" + """automatic uniform weighting with mass 1/N for `measure` in ATOL""" return np.ones(shape=measure.shape[0]) / measure.shape[0] class Atol(BaseEstimator, TransformerMixin): @@ -611,20 +611,41 @@ class Atol(BaseEstimator, TransformerMixin): Parameters: quantiser (Object): Object with `fit` (sklearn API consistent) and `cluster_centers` and `n_clusters` - attributes (default: MiniBatchKMeans()). This object will be fitted by the function `fit`. - weighting_method (function): constant generic function for weighting the measure points + attributes. This object will be fitted by the function `fit`. + weighting_method (string): constant generic function for weighting the measure points choose from {"cloud", "iidproba"} (default: constant function, i.e. the measure is seen as a point cloud by default). This will have no impact if weights are provided along with measures all the way: `fit` and `transform`. contrast (string): constant function for evaluating proximity of a measure with respect to centers - choose from {"gaus", "lapl", "indi"} + choose from {"gaussian", "laplacian", "indicator"} (default: laplacian contrast function, see page 3 in the ATOL paper). - """ + + Example + -------- + >>> from sklearn.cluster import KMeans + >>> import numpy as np + >>> a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]]) + >>> b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]]) + >>> c = np.array([[3, 2, -1], [1, 2, -1]]) + >>> atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2)) + >>> atol_vectoriser.fit(X=[a, b, c]) + >>> atol_vectoriser.centers + array([[ 2.6 , 2.8 , -0.4 ], + [ 2. , 0.66666667, 3.33333333]]) + >>> atol_vectoriser(a) + array([0.58394704, 1.0769395 ]) + >>> atol_vectoriser(c) + array([1.02816136, 0.23559623]) + >>> atol_vectoriser.transform(X=[a, b, c]) + array([[0.58394704, 1.0769395 ], + [1.04696684, 0.56203292], + [1.02816136, 0.23559623]]) + """ self.quantiser = quantiser self.contrast = { - "gaus": _gaus_contrast, - "lapl": _lapl_contrast, - "indi": _indicator_contrast, + "gaussian": _gaus_contrast, + "laplacian": _lapl_contrast, + "indicator": _indicator_contrast, }.get(contrast, _gaus_contrast) self.centers = np.ones(shape=(self.quantiser.n_clusters, 2))*np.inf self.inertias = np.full(self.quantiser.n_clusters, np.nan) -- cgit v1.2.3