From e36fa6c9511c387447ef77e062e26671505212a2 Mon Sep 17 00:00:00 2001 From: MathieuCarriere Date: Mon, 3 Aug 2020 12:05:38 -0400 Subject: fix wasserstein autodiff --- src/python/gudhi/wasserstein/wasserstein.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index b37d30bb..fe001c37 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -165,9 +165,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab # empty arrays are not handled properly by the helpers, so we avoid calling them if len(pairs_X_Y): dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) - if len(pairs_X_diag): + if len(pairs_X_diag[0]): dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p)) - if len(pairs_Y_diag): + if len(pairs_Y_diag[0]): dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p)) dists = [dist.reshape(1) for dist in dists] return ep.concatenate(dists).norms.lp(order).raw -- cgit v1.2.3