diff options
author | Vincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com> | 2020-07-01 03:09:08 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-01 03:09:08 -0700 |
commit | 8d316e831c6af51efb9c362a5b203528a9fd3b15 (patch) | |
tree | d06a69eaea859b36a15a90812d6c2de7f7018633 /src/python/gudhi/wasserstein/wasserstein.py | |
parent | d58fe553c806874d723c7e6d4fee368cc308e474 (diff) | |
parent | 0b4de61a18bc30f66a7fb45cc246cff2f55ba1a1 (diff) |
Merge pull request #349 from tlacombe/fix342
fix #342
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 89ecab1c..b37d30bb 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -73,8 +73,8 @@ def _perstot_autodiff(X, order, internal_p): def _perstot(X, order, internal_p, enable_autodiff): ''' :param X: (n x 2) numpy.array (points of a given diagram). - :param order: exponent for Wasserstein. Default value is 2. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). + :param order: exponent for Wasserstein. + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2). :param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. :type enable_autodiff: bool @@ -88,7 +88,7 @@ def _perstot(X, order, internal_p, enable_autodiff): return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) -def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False): +def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False): ''' :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). @@ -96,9 +96,9 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a :param matching: if True, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention (-1) represents the diagonal. - :param order: exponent for Wasserstein; Default value is 2. + :param order: exponent for Wasserstein; Default value is 1. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); - Default value is 2 (Euclidean norm). + Default value is `np.inf`. :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True`. |