summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-19 09:06:08 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-19 09:06:08 +0200
commitb2a9ba18ce33778abdd9f5032af4bfff04e8bbd2 (patch)
tree00be082047fd62acdc4be33922819936f2da47ac /src/python/gudhi/wasserstein
parentf93c403b81b4ccb98bfad8e4ef30cdf0e7333f6c (diff)
Unwrap the result
Diffstat (limited to 'src/python/gudhi/wasserstein')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 9660b99b..f0c82962 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -71,6 +71,7 @@ def _perstot(X, order, internal_p, enable_autodiff):
'''
if enable_autodiff:
import eagerpy as ep
+
return _perstot_autodiff(ep.astensor(X), order, internal_p).raw
else:
return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
@@ -118,6 +119,7 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a
if enable_autodiff:
import eagerpy as ep
+
X_orig = ep.astensor(X)
Y_orig = ep.astensor(Y)
X = X_orig.numpy()
@@ -140,10 +142,10 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a
return ot_cost ** (1./order) , match
if enable_autodiff:
- P = ot.emd(a=a,b=b,M=M, numItermax=2000000)
+ P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
pairs = np.argwhere(P[:-1, :-1])
- diag2 = np.nonzero(P[-1, :-1])
diag1 = np.nonzero(P[:-1, -1])
+ diag2 = np.nonzero(P[-1, :-1])
dists = []
# empty arrays are not handled properly by the helpers, so we avoid calling them
if len(pairs):
@@ -152,8 +154,8 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a
dists.append(_perstot_autodiff(X_orig[diag1], order, internal_p))
if len(diag2):
dists.append(_perstot_autodiff(Y_orig[diag2], order, internal_p))
- dists = [ dist.reshape(1) for dist in dists ]
- return ep.concatenate(dists).norms.lp(order)
+ dists = [dist.reshape(1) for dist in dists]
+ return ep.concatenate(dists).norms.lp(order).raw
# Should just compute the L^order norm manually?
# We can also concatenate the 3 vectors to compute just one norm.