diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2020-07-02 21:11:18 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2020-07-02 21:11:18 +0200 |
commit | 0f364d372a6bce81d895d4ccd066174bad260e9e (patch) | |
tree | c9067237d5790125fbcbb8fd784064f2b413a5ee /src/python/gudhi/wasserstein/wasserstein.py | |
parent | eedb34f25d76cb3dc7ccb6b59a60217a26eedfcd (diff) | |
parent | 3c7a4d01ec758d68a219fae8981c9847cf8d7a0f (diff) |
Merge branch 'master' into edge_collapse_integration_vincent
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 89ecab1c..b37d30bb 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -73,8 +73,8 @@ def _perstot_autodiff(X, order, internal_p): def _perstot(X, order, internal_p, enable_autodiff): ''' :param X: (n x 2) numpy.array (points of a given diagram). - :param order: exponent for Wasserstein. Default value is 2. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). + :param order: exponent for Wasserstein. + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2). :param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. :type enable_autodiff: bool @@ -88,7 +88,7 @@ def _perstot(X, order, internal_p, enable_autodiff): return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) -def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False): +def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False): ''' :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). @@ -96,9 +96,9 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a :param matching: if True, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention (-1) represents the diagonal. - :param order: exponent for Wasserstein; Default value is 2. + :param order: exponent for Wasserstein; Default value is 1. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); - Default value is 2 (Euclidean norm). + Default value is `np.inf`. :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True`. |