summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-12 10:37:27 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-12 10:37:27 +0200
commit69341c88c7c7819656c9a9b935fecc3bea50e4af (patch)
tree7fa0646180c04fb32854ca0aaf29d192d5e4118f /src/python/gudhi/wasserstein
parente94892f972357283e70c7534f84662dfaa21cc3e (diff)
parent7e05e915adc1be285e04eb00d3ab7ba1b797f38d (diff)
merge upstream/master into essential parts
Diffstat (limited to 'src/python/gudhi/wasserstein')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 142385b1..572d4249 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -202,8 +202,8 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
then the optimal matching will be set to `None`.
:param order: exponent for Wasserstein; Default value is 1.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
- default value is `np.inf`.
- :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
+ Default value is `np.inf`.
+ :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
with `matching=True` and with `keep_essential_parts=True`.
@@ -306,9 +306,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
# empty arrays are not handled properly by the helpers, so we avoid calling them
if len(pairs_X_Y):
dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
- if len(pairs_X_diag):
+ if len(pairs_X_diag[0]):
dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
- if len(pairs_Y_diag):
+ if len(pairs_Y_diag[0]):
dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
dists = [dist.reshape(1) for dist in dists]
return ep.concatenate(dists).norms.lp(order).raw