diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2020-08-21 15:40:55 +0200 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2020-08-21 15:40:55 +0200 |
commit | e34a7a37ba9ad9e00be7d7d17d6986da06a3cdac (patch) | |
tree | 434579073380993c9ae38d84efeb684ad4fe93cc /src/python/gudhi/wasserstein/wasserstein.py | |
parent | e8cfcc7017f9002a229996232fdb2d03a41d7ea7 (diff) | |
parent | 0e620832b3b6d2e1027ead9730dce872983dd2ce (diff) |
Merge branch 'master' into alpha_complex_3d_does_not_require_eigen
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index b37d30bb..a9d1cdff 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -99,7 +99,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab :param order: exponent for Wasserstein; Default value is 1. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is `np.inf`. - :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation + :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True`. @@ -165,9 +165,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab # empty arrays are not handled properly by the helpers, so we avoid calling them if len(pairs_X_Y): dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) - if len(pairs_X_diag): + if len(pairs_X_diag[0]): dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p)) - if len(pairs_Y_diag): + if len(pairs_Y_diag[0]): dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p)) dists = [dist.reshape(1) for dist in dists] return ep.concatenate(dists).norms.lp(order).raw |