summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein/wasserstein.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 89ecab1c..b37d30bb 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -73,8 +73,8 @@ def _perstot_autodiff(X, order, internal_p):
def _perstot(X, order, internal_p, enable_autodiff):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
- :param order: exponent for Wasserstein. Default value is 2.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
+ :param order: exponent for Wasserstein.
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2).
:param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
transparent to automatic differentiation.
:type enable_autodiff: bool
@@ -88,7 +88,7 @@ def _perstot(X, order, internal_p, enable_autodiff):
return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
-def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False):
+def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False):
'''
:param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
(i.e. with infinite coordinate).
@@ -96,9 +96,9 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a
:param matching: if True, computes and returns the optimal matching between X and Y, encoded as
a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
the j-th point in Y, with the convention (-1) represents the diagonal.
- :param order: exponent for Wasserstein; Default value is 2.
+ :param order: exponent for Wasserstein; Default value is 1.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
- Default value is 2 (Euclidean norm).
+ Default value is `np.inf`.
:param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
with `matching=True`.