From b0c4b1920faa512b25e7c34cbe83c8e625ac3613 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 29 Jun 2020 09:53:31 +0200 Subject: update default param for representation.metrics --- src/python/gudhi/representations/metrics.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index 8a32f7e9..3e8487c4 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -328,22 +328,28 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin): """ return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx) + class WassersteinDistance(BaseEstimator, TransformerMixin): """ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams. """ - def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01): + def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01): """ Constructor for the WassersteinDistance class. Parameters: - order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`. - internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`. - mode (str): method for computing Wasserstein distance. Either "pot" or "hera". + order (int): exponent for Wasserstein, default value is 1., see :func:`gudhi.wasserstein.wasserstein_distance`. + internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is `np.inf`, see :func:`gudhi.wasserstein.wasserstein_distance`. + mode (str): method for computing Wasserstein distance. Either "pot" or "hera". Default set to "hera". delta (float): relative error 1+delta. Used only if mode == "hera". """ self.order, self.internal_p, self.mode = order, internal_p, mode - self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein" + if mode == "pot": + self.metric = "pot_wasserstein" + elif mode == "hera": + self.metric = "hera_wasserstein" + else: + raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'") self.delta = delta def fit(self, X, y=None): -- cgit v1.2.3