From da2a7a68f8f57495080af37cf981f64228d165a2 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Wed, 22 Apr 2020 14:06:02 +0200 Subject: Rename local variables --- src/python/gudhi/wasserstein/wasserstein.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 42c8dc2d..3d1caeb3 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -154,17 +154,17 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a if enable_autodiff: P = ot.emd(a=a, b=b, M=M, numItermax=2000000) - pairs = np.argwhere(P[:-1, :-1]) - diag1 = np.nonzero(P[:-1, -1]) - diag2 = np.nonzero(P[-1, :-1]) + pairs_X_Y = np.argwhere(P[:-1, :-1]) + pairs_X_diag = np.nonzero(P[:-1, -1]) + pairs_Y_diag = np.nonzero(P[-1, :-1]) dists = [] # empty arrays are not handled properly by the helpers, so we avoid calling them - if len(pairs): - dists.append((Y_orig[pairs[:, 1]] - X_orig[pairs[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) - if len(diag1): - dists.append(_perstot_autodiff(X_orig[diag1], order, internal_p)) - if len(diag2): - dists.append(_perstot_autodiff(Y_orig[diag2], order, internal_p)) + if len(pairs_X_Y): + dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) + if len(pairs_X_diag): + dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p)) + if len(pairs_Y_diag): + dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p)) dists = [dist.reshape(1) for dist in dists] return ep.concatenate(dists).norms.lp(order).raw # We can also concatenate the 3 vectors to compute just one norm. -- cgit v1.2.3