summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/wasserstein')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 142385b1..572d4249 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -202,8 +202,8 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
then the optimal matching will be set to `None`.
:param order: exponent for Wasserstein; Default value is 1.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
- default value is `np.inf`.
- :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
+ Default value is `np.inf`.
+ :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
with `matching=True` and with `keep_essential_parts=True`.
@@ -306,9 +306,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
# empty arrays are not handled properly by the helpers, so we avoid calling them
if len(pairs_X_Y):
dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
- if len(pairs_X_diag):
+ if len(pairs_X_diag[0]):
dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
- if len(pairs_Y_diag):
+ if len(pairs_Y_diag[0]):
dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
dists = [dist.reshape(1) for dist in dists]
return ep.concatenate(dists).norms.lp(order).raw