From 4fe1d853d32846b8f34699c7121b4d7edfeb48ac Mon Sep 17 00:00:00 2001 From: tlacombe Date: Thu, 11 Jun 2020 12:15:22 +0200 Subject: added file --- src/python/gudhi/wasserstein/wasserstein.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 89ecab1c..6d7ec65f 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -88,7 +88,7 @@ def _perstot(X, order, internal_p, enable_autodiff): return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) -def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False): +def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False): ''' :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). -- cgit v1.2.3 From b0c4b1920faa512b25e7c34cbe83c8e625ac3613 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 29 Jun 2020 09:53:31 +0200 Subject: update default param for representation.metrics --- src/python/gudhi/representations/metrics.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index 8a32f7e9..3e8487c4 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -328,22 +328,28 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin): """ return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx) + class WassersteinDistance(BaseEstimator, TransformerMixin): """ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams. """ - def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01): + def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01): """ Constructor for the WassersteinDistance class. Parameters: - order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`. - internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`. - mode (str): method for computing Wasserstein distance. Either "pot" or "hera". + order (int): exponent for Wasserstein, default value is 1., see :func:`gudhi.wasserstein.wasserstein_distance`. + internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is `np.inf`, see :func:`gudhi.wasserstein.wasserstein_distance`. + mode (str): method for computing Wasserstein distance. Either "pot" or "hera". Default set to "hera". delta (float): relative error 1+delta. Used only if mode == "hera". """ self.order, self.internal_p, self.mode = order, internal_p, mode - self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein" + if mode == "pot": + self.metric = "pot_wasserstein" + elif mode == "hera": + self.metric = "hera_wasserstein" + else: + raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'") self.delta = delta def fit(self, X, y=None): -- cgit v1.2.3 From 6c65d29acc3b03d21beca653834340787bf0c65e Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 29 Jun 2020 10:21:13 +0200 Subject: update doc in wasserstein.wasserstein to reflect change in default param --- src/python/gudhi/wasserstein/wasserstein.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 6d7ec65f..b37d30bb 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -73,8 +73,8 @@ def _perstot_autodiff(X, order, internal_p): def _perstot(X, order, internal_p, enable_autodiff): ''' :param X: (n x 2) numpy.array (points of a given diagram). - :param order: exponent for Wasserstein. Default value is 2. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). + :param order: exponent for Wasserstein. + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2). :param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. :type enable_autodiff: bool @@ -96,9 +96,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab :param matching: if True, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention (-1) represents the diagonal. - :param order: exponent for Wasserstein; Default value is 2. + :param order: exponent for Wasserstein; Default value is 1. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); - Default value is 2 (Euclidean norm). + Default value is `np.inf`. :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True`. -- cgit v1.2.3