summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2019-10-31 08:48:15 +0100
committerMarc Glisse <marc.glisse@inria.fr>2019-10-31 08:48:15 +0100
commitee4934750e8c9dbdee4874d56921aeb9bf7b7bb7 (patch)
treecb8d67d0b1b944351cf5dbcfed932b4fc253e41a /src/python/gudhi/wasserstein.py
parent3c76f73a530daacd48d476cd96bd946e4ab6d78a (diff)
Increase numItermax in the call to POT.
This number is pretty arbitrary...
Diffstat (limited to 'src/python/gudhi/wasserstein.py')
-rw-r--r--src/python/gudhi/wasserstein.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
index 445772e4..eba7c6d5 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein.py
@@ -92,7 +92,8 @@ def wasserstein_distance(X, Y, p=2., q=2.):
# Comptuation of the otcost using the ot.emd2 library.
# Note: it is the squared Wasserstein distance.
- ot_cost = (n+m) * ot.emd2(a, b, M)
+ # The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
+ ot_cost = (n+m) * ot.emd2(a, b, M, numItermax=2000000)
return ot_cost ** (1./p)