From e36fa6c9511c387447ef77e062e26671505212a2 Mon Sep 17 00:00:00 2001 From: MathieuCarriere Date: Mon, 3 Aug 2020 12:05:38 -0400 Subject: fix wasserstein autodiff --- src/python/gudhi/wasserstein/wasserstein.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index b37d30bb..fe001c37 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -165,9 +165,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab # empty arrays are not handled properly by the helpers, so we avoid calling them if len(pairs_X_Y): dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) - if len(pairs_X_diag): + if len(pairs_X_diag[0]): dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p)) - if len(pairs_Y_diag): + if len(pairs_Y_diag[0]): dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p)) dists = [dist.reshape(1) for dist in dists] return ep.concatenate(dists).norms.lp(order).raw -- cgit v1.2.3 From 47f7f50c8cdb40a0a1fd73432498004f21641803 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Tue, 4 Aug 2020 12:46:06 +0200 Subject: Remove JAX from the documentation jax.grad does not work with our functions (I think it used to work...) --- src/python/gudhi/point_cloud/knn.py | 2 +- src/python/gudhi/wasserstein/wasserstein.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/gudhi/point_cloud/knn.py b/src/python/gudhi/point_cloud/knn.py index 4652fe80..994be3b6 100644 --- a/src/python/gudhi/point_cloud/knn.py +++ b/src/python/gudhi/point_cloud/knn.py @@ -46,7 +46,7 @@ class KNearestNeighbors: sort_results (bool): if True, then distances and indices of each point are sorted on return, so that the first column contains the closest points. Otherwise, neighbors are returned in an arbitrary order. Defaults to True. - enable_autodiff (bool): if the input is a torch.tensor, jax.numpy.ndarray or tensorflow.Tensor, this + enable_autodiff (bool): if the input is a torch.tensor or tensorflow.Tensor, this instructs the function to compute distances in a way that works with automatic differentiation. This is experimental, not supported for all metrics, and requires the package EagerPy. Defaults to False. diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index fe001c37..a9d1cdff 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -99,7 +99,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab :param order: exponent for Wasserstein; Default value is 1. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is `np.inf`. - :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation + :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True`. -- cgit v1.2.3