summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 13:29:46 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 13:29:46 +0900
commitf8c1c8740f9974dcf4aaf191851d62149dceb91c (patch)
treebe6d175a98a6803a216b72326e054b27e4065839 /test
parenta3497b123b4802c7960a07a899ac7ce4525c5995 (diff)
Added MAX_ITER_REACHED flag and warning
Diffstat (limited to 'test')
-rw-r--r--test/test_ot.py52
1 files changed, 50 insertions, 2 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 6f0f7c9..8a19cf6 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -8,6 +8,7 @@ import numpy as np
import ot
from ot.datasets import get_1D_gauss as gauss
+import warnings
def test_doctest():
@@ -100,9 +101,56 @@ def test_emd2_multi():
np.testing.assert_allclose(emd1, emdn)
-def test_dual_variables():
- # %% parameters
+def test_warnings():
+ n = 100 # nb bins
+ m = 100 # nb bins
+
+ mean1 = 30
+ mean2 = 50
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+ y = np.arange(m, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=mean1, s=5) # m= mean, s= std
+
+ b = gauss(m, m=mean2, s=10)
+ # loss matrix
+ M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
+ # M/=M.max()
+
+ # %%
+
+ print('Computing {} EMD '.format(1))
+ G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ with warnings.catch_warnings(record=True) as w:
+ # Cause all warnings to always be triggered.
+ warnings.simplefilter("always")
+ # Trigger a warning.
+ print('Computing {} EMD '.format(1))
+ G, alpha, beta = ot.emd(a, b, M, dual_variables=True, numItermax=1)
+ # Verify some things
+ assert "numItermax" in str(w[-1].message)
+ assert len(w) == 1
+ # Trigger a warning.
+ a[0]=100
+ print('Computing {} EMD '.format(2))
+ G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ # Verify some things
+ assert "infeasible" in str(w[-1].message)
+ assert len(w) == 2
+ # Trigger a warning.
+ a[0]=-1
+ print('Computing {} EMD '.format(2))
+ G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ # Verify some things
+ assert "infeasible" in str(w[-1].message)
+ assert len(w) == 3
+
+
+def test_dual_variables():
n = 5000 # nb bins
m = 6000 # nb bins