diff options
author | tlacombe <lacombe1993@gmail.com> | 2021-04-12 10:37:27 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2021-04-12 10:37:27 +0200 |
commit | 69341c88c7c7819656c9a9b935fecc3bea50e4af (patch) | |
tree | 7fa0646180c04fb32854ca0aaf29d192d5e4118f /src/python/gudhi/wasserstein | |
parent | e94892f972357283e70c7534f84662dfaa21cc3e (diff) | |
parent | 7e05e915adc1be285e04eb00d3ab7ba1b797f38d (diff) |
merge upstream/master into essential parts
Diffstat (limited to 'src/python/gudhi/wasserstein')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 142385b1..572d4249 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -202,8 +202,8 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab then the optimal matching will be set to `None`. :param order: exponent for Wasserstein; Default value is 1. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); - default value is `np.inf`. - :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation + Default value is `np.inf`. + :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True` and with `keep_essential_parts=True`. @@ -306,9 +306,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab # empty arrays are not handled properly by the helpers, so we avoid calling them if len(pairs_X_Y): dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) - if len(pairs_X_diag): + if len(pairs_X_diag[0]): dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p)) - if len(pairs_Y_diag): + if len(pairs_Y_diag[0]): dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p)) dists = [dist.reshape(1) for dist in dists] return ep.concatenate(dists).norms.lp(order).raw |