summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2020-08-03 12:05:38 -0400
committerMarc Glisse <marc.glisse@inria.fr>2020-08-04 12:42:55 +0200
commite36fa6c9511c387447ef77e062e26671505212a2 (patch)
treeef14a625efcf0a4592e1925852f9df2cf4f1f7af
parent51081053fe2fddd518303b6521a61dc7fbdab4a8 (diff)
fix wasserstein autodiff
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index b37d30bb..fe001c37 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -165,9 +165,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
# empty arrays are not handled properly by the helpers, so we avoid calling them
if len(pairs_X_Y):
dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
- if len(pairs_X_diag):
+ if len(pairs_X_diag[0]):
dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
- if len(pairs_Y_diag):
+ if len(pairs_Y_diag[0]):
dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
dists = [dist.reshape(1) for dist in dists]
return ep.concatenate(dists).norms.lp(order).raw