diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-19 09:06:08 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-19 09:06:08 +0200 |
commit | b2a9ba18ce33778abdd9f5032af4bfff04e8bbd2 (patch) | |
tree | 00be082047fd62acdc4be33922819936f2da47ac /src/python/gudhi/wasserstein/wasserstein.py | |
parent | f93c403b81b4ccb98bfad8e4ef30cdf0e7333f6c (diff) |
Unwrap the result
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 9660b99b..f0c82962 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -71,6 +71,7 @@ def _perstot(X, order, internal_p, enable_autodiff): ''' if enable_autodiff: import eagerpy as ep + return _perstot_autodiff(ep.astensor(X), order, internal_p).raw else: return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) @@ -118,6 +119,7 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a if enable_autodiff: import eagerpy as ep + X_orig = ep.astensor(X) Y_orig = ep.astensor(Y) X = X_orig.numpy() @@ -140,10 +142,10 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a return ot_cost ** (1./order) , match if enable_autodiff: - P = ot.emd(a=a,b=b,M=M, numItermax=2000000) + P = ot.emd(a=a, b=b, M=M, numItermax=2000000) pairs = np.argwhere(P[:-1, :-1]) - diag2 = np.nonzero(P[-1, :-1]) diag1 = np.nonzero(P[:-1, -1]) + diag2 = np.nonzero(P[-1, :-1]) dists = [] # empty arrays are not handled properly by the helpers, so we avoid calling them if len(pairs): @@ -152,8 +154,8 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a dists.append(_perstot_autodiff(X_orig[diag1], order, internal_p)) if len(diag2): dists.append(_perstot_autodiff(Y_orig[diag2], order, internal_p)) - dists = [ dist.reshape(1) for dist in dists ] - return ep.concatenate(dists).norms.lp(order) + dists = [dist.reshape(1) for dist in dists] + return ep.concatenate(dists).norms.lp(order).raw # Should just compute the L^order norm manually? # We can also concatenate the 3 vectors to compute just one norm. |