summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-08-04 13:22:01 +0200
committerGitHub <noreply@github.com>2020-08-04 13:22:01 +0200
commitf98133a79b8a34ade8a0c214f75264e955420e7e (patch)
treec5fdbd3505989b1913aacb4a4320c7fbf53071b4
parent51081053fe2fddd518303b6521a61dc7fbdab4a8 (diff)
parent47f7f50c8cdb40a0a1fd73432498004f21641803 (diff)
Merge pull request #371 from MathieuCarriere/master
Wasserstein autodiff
-rw-r--r--src/python/gudhi/point_cloud/knn.py2
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py6
2 files changed, 4 insertions, 4 deletions
diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py
index 4652fe80..994be3b6 100644
--- a/src/python/gudhi/point_cloud/knn.py
+++ b/src/python/gudhi/point_cloud/knn.py
@@ -46,7 +46,7 @@ class KNearestNeighbors:
sort_results (bool): if True, then distances and indices of each point are
sorted on return, so that the first column contains the closest points.
Otherwise, neighbors are returned in an arbitrary order. Defaults to True.
- enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.ndarray or tensorflow.Tensor, this
+ enable_autodiff (bool): if the input is a torch.tensor or tensorflow.Tensor, this
instructs the function to compute distances in a way that works with automatic differentiation.
This is experimental, not supported for all metrics, and requires the package EagerPy.
Defaults to False.
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index b37d30bb..a9d1cdff 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -99,7 +99,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
:param order: exponent for Wasserstein; Default value is 1.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
Default value is `np.inf`.
- :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
+ :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
with `matching=True`.
@@ -165,9 +165,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
# empty arrays are not handled properly by the helpers, so we avoid calling them
if len(pairs_X_Y):
dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
- if len(pairs_X_diag):
+ if len(pairs_X_diag[0]):
dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
- if len(pairs_Y_diag):
+ if len(pairs_Y_diag[0]):
dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
dists = [dist.reshape(1) for dist in dists]
return ep.concatenate(dists).norms.lp(order).raw