diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2019-10-31 08:48:15 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2019-10-31 08:48:15 +0100 |
commit | ee4934750e8c9dbdee4874d56921aeb9bf7b7bb7 (patch) | |
tree | cb8d67d0b1b944351cf5dbcfed932b4fc253e41a /src/python/gudhi | |
parent | 3c76f73a530daacd48d476cd96bd946e4ab6d78a (diff) |
Increase numItermax in the call to POT.
This number is pretty arbitrary...
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/wasserstein.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py index 445772e4..eba7c6d5 100644 --- a/src/python/gudhi/wasserstein.py +++ b/src/python/gudhi/wasserstein.py @@ -92,7 +92,8 @@ def wasserstein_distance(X, Y, p=2., q=2.): # Comptuation of the otcost using the ot.emd2 library. # Note: it is the squared Wasserstein distance. - ot_cost = (n+m) * ot.emd2(a, b, M) + # The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value? + ot_cost = (n+m) * ot.emd2(a, b, M, numItermax=2000000) return ot_cost ** (1./p) |