diff options
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/representations/metrics.py | 17 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 10 |
2 files changed, 17 insertions, 10 deletions
diff --git a/src/python/gudhi/representations/metrics.py b/src/python/gudhi/representations/metrics.py index cf2e0879..142ddef1 100644 --- a/src/python/gudhi/representations/metrics.py +++ b/src/python/gudhi/representations/metrics.py @@ -350,23 +350,30 @@ class PersistenceFisherDistance(BaseEstimator, TransformerMixin): """ return _persistence_fisher_distance(diag1, diag2, bandwidth=self.bandwidth, kernel_approx=self.kernel_approx) + class WassersteinDistance(BaseEstimator, TransformerMixin): """ This is a class for computing the Wasserstein distance matrix from a list of persistence diagrams. """ - def __init__(self, order=2, internal_p=2, mode="pot", delta=0.01, n_jobs=None): + + def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01, n_jobs=None): """ Constructor for the WassersteinDistance class. Parameters: - order (int): exponent for Wasserstein, default value is 2., see :func:`gudhi.wasserstein.wasserstein_distance`. - internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is 2 (euclidean norm), see :func:`gudhi.wasserstein.wasserstein_distance`. - mode (str): method for computing Wasserstein distance. Either "pot" or "hera". + order (int): exponent for Wasserstein, default value is 1., see :func:`gudhi.wasserstein.wasserstein_distance`. + internal_p (int): ground metric on the (upper-half) plane (i.e. norm l_p in R^2), default value is `np.inf`, see :func:`gudhi.wasserstein.wasserstein_distance`. + mode (str): method for computing Wasserstein distance. Either "pot" or "hera". Default set to "hera". delta (float): relative error 1+delta. Used only if mode == "hera". n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_distances` for details. """ self.order, self.internal_p, self.mode = order, internal_p, mode - self.metric = "pot_wasserstein" if mode == "pot" else "hera_wasserstein" + if mode == "pot": + self.metric = "pot_wasserstein" + elif mode == "hera": + self.metric = "hera_wasserstein" + else: + raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'") self.delta = delta self.n_jobs = n_jobs diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 89ecab1c..b37d30bb 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -73,8 +73,8 @@ def _perstot_autodiff(X, order, internal_p): def _perstot(X, order, internal_p, enable_autodiff): ''' :param X: (n x 2) numpy.array (points of a given diagram). - :param order: exponent for Wasserstein. Default value is 2. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). + :param order: exponent for Wasserstein. + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2). :param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. :type enable_autodiff: bool @@ -88,7 +88,7 @@ def _perstot(X, order, internal_p, enable_autodiff): return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) -def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False): +def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False): ''' :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). @@ -96,9 +96,9 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a :param matching: if True, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention (-1) represents the diagonal. - :param order: exponent for Wasserstein; Default value is 2. + :param order: exponent for Wasserstein; Default value is 1. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); - Default value is 2 (Euclidean norm). + Default value is `np.inf`. :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True`. |