diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 114 |
1 files changed, 110 insertions, 4 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index acd8718..ea6d9dc 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -4,12 +4,15 @@
 #
 # License: MIT License
 
+import warnings
+
 import numpy as np
+
 import ot
+from ot.datasets import get_1D_gauss as gauss
 
 
 def test_doctest():
-
     import doctest
 
     # test lp solver
@@ -66,9 +69,6 @@ def test_emd_empty():
 
 
 def test_emd2_multi():
-
-    from ot.datasets import get_1D_gauss as gauss
-
     n = 1000  # nb bins
 
     # bin positions
@@ -100,3 +100,109 @@ def test_emd2_multi():
     ot.toc('multi proc : {} s')
 
     np.testing.assert_allclose(emd1, emdn)
+
+    # emd loss multipro proc with log
+    ot.tic()
+    emdn = ot.emd2(a, b, M, log=True, return_matrix=True)
+    ot.toc('multi proc : {} s')
+
+    for i in range(len(emdn)):
+        emd = emdn[i]
+        log = emd[1]
+        cost = emd[0]
+        check_duality_gap(a, b[:, i], M, log['G'], log['u'], log['v'], cost)
+        emdn[i] = cost
+
+    emdn = np.array(emdn)
+    np.testing.assert_allclose(emd1, emdn)
+
+
+def test_warnings():
+    n = 100  # nb bins
+    m = 100  # nb bins
+
+    mean1 = 30
+    mean2 = 50
+
+    # bin positions
+    x = np.arange(n, dtype=np.float64)
+    y = np.arange(m, dtype=np.float64)
+
+    # Gaussian distributions
+    a = gauss(n, m=mean1, s=5)  # m= mean, s= std
+
+    b = gauss(m, m=mean2, s=10)
+
+    # loss matrix
+    M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
+
+    print('Computing {} EMD '.format(1))
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        print('Computing {} EMD '.format(1))
+        ot.emd(a, b, M, numItermax=1)
+        assert "numItermax" in str(w[-1].message)
+        assert len(w) == 1
+        a[0] = 100
+        print('Computing {} EMD '.format(2))
+        ot.emd(a, b, M)
+        assert "infeasible" in str(w[-1].message)
+        assert len(w) == 2
+        a[0] = -1
+        print('Computing {} EMD '.format(2))
+        ot.emd(a, b, M)
+        assert "infeasible" in str(w[-1].message)
+        assert len(w) == 3
+
+
+def test_dual_variables():
+    n = 5000  # nb bins
+    m = 6000  # nb bins
+
+    mean1 = 1000
+    mean2 = 1100
+
+    # bin positions
+    x = np.arange(n, dtype=np.float64)
+    y = np.arange(m, dtype=np.float64)
+
+    # Gaussian distributions
+    a = gauss(n, m=mean1, s=5)  # m= mean, s= std
+
+    b = gauss(m, m=mean2, s=10)
+
+    # loss matrix
+    M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
+
+    print('Computing {} EMD '.format(1))
+
+    # emd loss 1 proc
+    ot.tic()
+    G, log = ot.emd(a, b, M, log=True)
+    ot.toc('1 proc : {} s')
+
+    ot.tic()
+    G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
+    ot.toc('1 proc : {} s')
+
+    cost1 = (G * M).sum()
+    # Check symmetry
+    np.testing.assert_array_almost_equal(cost1, (M * G2.T).sum())
+    # Check with closed-form solution for gaussians
+    np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2))
+
+    # Check that both cost computations are equivalent
+    np.testing.assert_almost_equal(cost1, log['cost'])
+    check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
+
+
+def check_duality_gap(a, b, M, G, u, v, cost):
+    cost_dual = np.vdot(a, u) + np.vdot(b, v)
+    # Check that dual and primal cost are equal
+    np.testing.assert_almost_equal(cost_dual, cost)
+
+    [ind1, ind2] = np.nonzero(G)
+
+    # Check that reduced cost is zero on transport arcs
+    np.testing.assert_array_almost_equal((M - u.reshape(-1, 1) - v.reshape(1, -1))[ind1, ind2],
+                                         np.zeros(ind1.size))