summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations/vector_methods.py
diff options
context:
space:
mode:
authormartinroyer <16647869+martinroyer@users.noreply.github.com>2020-06-11 16:47:16 +0200
committerGitHub <noreply@github.com>2020-06-11 16:47:16 +0200
commit76529cae58f8a2736a1730fd81a9e12c3f4c7e19 (patch)
treebbd889924f81f5737b3df46846338ec7f8ba95b9 /src/python/gudhi/representations/vector_methods.py
parentb7a3e9ee0065d70438bacc1bae09272f9be7adaf (diff)
Apply suggestions from code review #456
(thank you Marc!)
Diffstat (limited to 'src/python/gudhi/representations/vector_methods.py')
-rw-r--r--src/python/gudhi/representations/vector_methods.py19
1 files changed, 5 insertions, 14 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index ede1087f..49c05c51 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -586,10 +586,8 @@ def _gaus_contrast(measure, centers, inertias):
def _indicator_contrast(diags, centers, inertias):
"""contrast function for vectorising `measure` in ATOL"""
- pair_dist = pairwise.pairwise_distances(diags, Y=centers)
- flat_circ = (pair_dist < inertias).astype(int)
- robe_curve = np.clip(2-pair_dist/inertias, 0, 1)
- return flat_circ + robe_curve
+ robe_curve = np.clip(2-pairwise.pairwise_distances(diags, Y=centers)/inertias, 0, 1)
+ return robe_curve
def _cloud_weighting(measure):
"""automatic uniform weighting with mass 1 for `measure` in ATOL"""
@@ -603,7 +601,7 @@ class Atol(BaseEstimator, TransformerMixin):
"""
This class allows to vectorise measures (e.g. point clouds, persistence diagrams, etc) after a quantisation step.
- ATOL paper: https://arxiv.org/abs/1909.13472
+ ATOL paper: :cite:`royer2019atol`
Example
--------
@@ -632,9 +630,9 @@ class Atol(BaseEstimator, TransformerMixin):
Parameters:
quantiser (Object): Object with `fit` (sklearn API consistent) and `cluster_centers` and `n_clusters`
- attributes. This object will be fitted by the function `fit`.
+ attributes, e.g. sklearn.cluster.KMeans. It will be fitted when the Atol object function `fit` is called.
weighting_method (string): constant generic function for weighting the measure points
- choose from {"cloud", "iidproba"}
+ choose from {"cloud", "iidproba"}
(default: constant function, i.e. the measure is seen as a point cloud by default).
This will have no impact if weights are provided along with measures all the way: `fit` and `transform`.
contrast (string): constant function for evaluating proximity of a measure with respect to centers
@@ -647,8 +645,6 @@ class Atol(BaseEstimator, TransformerMixin):
"laplacian": _lapl_contrast,
"indicator": _indicator_contrast,
}.get(contrast, _gaus_contrast)
- self.centers = np.ones(shape=(self.quantiser.n_clusters, 2))*np.inf
- self.inertias = np.full(self.quantiser.n_clusters, np.nan)
self.weighting_method = {
"cloud" : _cloud_weighting,
"iidproba": _iidproba_weighting,
@@ -670,11 +666,6 @@ class Atol(BaseEstimator, TransformerMixin):
"""
if not hasattr(self.quantiser, 'fit'):
raise TypeError("quantiser %s has no `fit` attribute." % (self.quantiser))
- if np.sum([measure.shape[0] for measure in X]) < self.quantiser.n_clusters:
- # in case there are not enough observations for fitting the quantiser, we add random points in [0, 1]^2
- # @Martin: perhaps this behaviour is to be externalised and a warning should be raised instead
- random_points = np.random.rand(self.quantiser.n_clusters-len(X), X[0].shape[1])
- X.append(random_points)
if sample_weight is None:
sample_weight = np.concatenate([self.weighting_method(measure) for measure in X])