From d52b4ea415d9bb669be04ccd0940f9b3d258d0e1 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Tue, 5 Sep 2017 17:15:45 +0900 Subject: Fixed typo and merged emd tests --- test/test_ot.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 3 deletions(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index acd8718..ded6c9f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -6,6 +6,8 @@ import numpy as np import ot + +from ot.datasets import get_1D_gauss as gauss def test_doctest(): @@ -66,9 +68,6 @@ def test_emd_empty(): def test_emd2_multi(): - - from ot.datasets import get_1D_gauss as gauss - n = 1000 # nb bins # bin positions @@ -100,3 +99,64 @@ def test_emd2_multi(): ot.toc('multi proc : {} s') np.testing.assert_allclose(emd1, emdn) + +def test_dual_variables(): + #%% parameters + + n=5000 # nb bins + m=6000 # nb bins + + mean1 = 1000 + mean2 = 1100 + + # bin positions + x=np.arange(n,dtype=np.float64) + y=np.arange(m,dtype=np.float64) + + # Gaussian distributions + a=gauss(n,m=mean1,s=5) # m= mean, s= std + + b=gauss(m,m=mean2,s=10) + + # loss matrix + M=ot.dist(x.reshape((-1,1)), y.reshape((-1,1))) ** (1./2) + #M/=M.max() + + #%% + + print('Computing {} EMD '.format(1)) + + # emd loss 1 proc + ot.tic() + G, alpha, beta = ot.emd(a,b,M, dual_variables=True) + ot.toc('1 proc : {} s') + + cost1 = (G * M).sum() + cost_dual = np.vdot(a, alpha) + np.vdot(b, beta) + + # emd loss 1 proc + ot.tic() + cost_emd2 = ot.emd2(a,b,M) + ot.toc('1 proc : {} s') + + ot.tic() + G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) + ot.toc('1 proc : {} s') + + cost2 = (G2 * M.T).sum() + + M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1) + + # Check that both cost computations are equivalent + np.testing.assert_almost_equal(cost1, cost_emd2) + # Check that dual and primal cost are equal + np.testing.assert_almost_equal(cost1, cost_dual) + # Check symmetry + np.testing.assert_almost_equal(cost1, cost2) + # Check with closed-form solution for gaussians + np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2)) + + [ind1, ind2] = np.nonzero(G) + + # Check that reduced cost is zero on transport arcs + np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) -- cgit v1.2.3