summaryrefslogtreecommitdiff
path: root/src/python/gudhi/representations
diff options
context:
space:
mode:
authormartinroyer <16647869+martinroyer@users.noreply.github.com>2020-06-10 10:20:13 +0200
committerGitHub <noreply@github.com>2020-06-10 10:20:13 +0200
commitcdba6045ddf1dd41e8addb7351d1c87a5506ba0f (patch)
treed90e03c1471fb3e154990cbf71050d5ceca70132 /src/python/gudhi/representations
parent3d126356fd3fcaeb2bde8824b8c5894450fccdd9 (diff)
Apply suggestions from code review
Diffstat (limited to 'src/python/gudhi/representations')
-rw-r--r--src/python/gudhi/representations/vector_methods.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 77b2836f..667f963b 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -578,17 +578,17 @@ class ComplexPolynomial(BaseEstimator, TransformerMixin):
def _lapl_contrast(measure, centers, inertias, eps=1e-8):
"""contrast function for vectorising `measure` in ATOL"""
- return np.exp(-np.sqrt(pairwise.pairwise_distances(measure, Y=centers) / (inertias + eps)))
+ return np.exp(-pairwise.pairwise_distances(measure, Y=centers) / (inertias + eps))
def _gaus_contrast(measure, centers, inertias, eps=1e-8):
"""contrast function for vectorising `measure` in ATOL"""
- return np.exp(-pairwise.pairwise_distances(measure, Y=centers) / (inertias + eps))
+ return np.exp(-pairwise.pairwise_distances(measure, Y=centers)**2 / (inertias**2 + eps))
def _indicator_contrast(diags, centers, inertias, eps=1e-8):
"""contrast function for vectorising `measure` in ATOL"""
pair_dist = pairwise.pairwise_distances(diags, Y=centers)
flat_circ = (pair_dist < (inertias+eps)).astype(int)
- robe_curve = np.positive((2-pair_dist/(inertias+eps))*((inertias+eps) < pair_dist).astype(int))
+ robe_curve = np.clip(2-pair_dist/(inertias+eps), 0, 1)
return flat_circ + robe_curve
def _cloud_weighting(measure):
@@ -638,7 +638,7 @@ class Atol(BaseEstimator, TransformerMixin):
(default: constant function, i.e. the measure is seen as a point cloud by default).
This will have no impact if weights are provided along with measures all the way: `fit` and `transform`.
contrast (string): constant function for evaluating proximity of a measure with respect to centers
- choose from {"gaussian", "laplacian", "indicator"}
+ choose from {"gaussian", "laplacian", "indicator"}
(default: gaussian contrast function, see page 3 in the ATOL paper).
"""
self.quantiser = quantiser
@@ -670,7 +670,7 @@ class Atol(BaseEstimator, TransformerMixin):
"""
if not hasattr(self.quantiser, 'fit'):
raise TypeError("quantiser %s has no `fit` attribute." % (self.quantiser))
- if len(X) < self.quantiser.n_clusters:
+ if np.sum([measure.shape[0] for measure in X]) < self.quantiser.n_clusters:
# in case there are not enough observations for fitting the quantiser, we add random points in [0, 1]^2
# @Martin: perhaps this behaviour is to be externalised and a warning should be raised instead
random_points = np.random.rand(self.quantiser.n_clusters-len(X), X[0].shape[1])
@@ -681,7 +681,6 @@ class Atol(BaseEstimator, TransformerMixin):
measures_concat = np.concatenate(X)
self.quantiser.fit(X=measures_concat, sample_weight=sample_weight)
self.centers = self.quantiser.cluster_centers_
- labels = np.argmin(pairwise.pairwise_distances(measures_concat, Y=self.centers), axis=1)
dist_centers = pairwise.pairwise_distances(self.centers)
np.fill_diagonal(dist_centers, np.inf)
self.inertias = np.min(dist_centers, axis=0)/2