summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 42c8dc2d..3d1caeb3 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -154,17 +154,17 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a
if enable_autodiff:
P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
- pairs = np.argwhere(P[:-1, :-1])
- diag1 = np.nonzero(P[:-1, -1])
- diag2 = np.nonzero(P[-1, :-1])
+ pairs_X_Y = np.argwhere(P[:-1, :-1])
+ pairs_X_diag = np.nonzero(P[:-1, -1])
+ pairs_Y_diag = np.nonzero(P[-1, :-1])
dists = []
# empty arrays are not handled properly by the helpers, so we avoid calling them
- if len(pairs):
- dists.append((Y_orig[pairs[:, 1]] - X_orig[pairs[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
- if len(diag1):
- dists.append(_perstot_autodiff(X_orig[diag1], order, internal_p))
- if len(diag2):
- dists.append(_perstot_autodiff(Y_orig[diag2], order, internal_p))
+ if len(pairs_X_Y):
+ dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
+ if len(pairs_X_diag):
+ dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
+ if len(pairs_Y_diag):
+ dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
dists = [dist.reshape(1) for dist in dists]
return ep.concatenate(dists).norms.lp(order).raw
# We can also concatenate the 3 vectors to compute just one norm.