diff options
author | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-07 13:29:46 +0900 |
---|---|---|
committer | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-07 13:29:46 +0900 |
commit | f8c1c8740f9974dcf4aaf191851d62149dceb91c (patch) | |
tree | be6d175a98a6803a216b72326e054b27e4065839 /test | |
parent | a3497b123b4802c7960a07a899ac7ce4525c5995 (diff) |
Added MAX_ITER_REACHED flag and warning
Diffstat (limited to 'test')
-rw-r--r-- | test/test_ot.py | 52 |
1 files changed, 50 insertions, 2 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 6f0f7c9..8a19cf6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -8,6 +8,7 @@ import numpy as np import ot from ot.datasets import get_1D_gauss as gauss +import warnings def test_doctest(): @@ -100,9 +101,56 @@ def test_emd2_multi(): np.testing.assert_allclose(emd1, emdn) -def test_dual_variables(): - # %% parameters +def test_warnings(): + n = 100 # nb bins + m = 100 # nb bins + + mean1 = 30 + mean2 = 50 + + # bin positions + x = np.arange(n, dtype=np.float64) + y = np.arange(m, dtype=np.float64) + + # Gaussian distributions + a = gauss(n, m=mean1, s=5) # m= mean, s= std + + b = gauss(m, m=mean2, s=10) + # loss matrix + M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2) + # M/=M.max() + + # %% + + print('Computing {} EMD '.format(1)) + G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Trigger a warning. + print('Computing {} EMD '.format(1)) + G, alpha, beta = ot.emd(a, b, M, dual_variables=True, numItermax=1) + # Verify some things + assert "numItermax" in str(w[-1].message) + assert len(w) == 1 + # Trigger a warning. + a[0]=100 + print('Computing {} EMD '.format(2)) + G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + # Verify some things + assert "infeasible" in str(w[-1].message) + assert len(w) == 2 + # Trigger a warning. + a[0]=-1 + print('Computing {} EMD '.format(2)) + G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + # Verify some things + assert "infeasible" in str(w[-1].message) + assert len(w) == 3 + + +def test_dual_variables(): n = 5000 # nb bins m = 6000 # nb bins |