diff options
author | arolet <antoine.rolet@gmail.com> | 2017-07-21 13:34:09 +0900 |
---|---|---|
committer | arolet <antoine.rolet@gmail.com> | 2017-07-21 13:34:09 +0900 |
commit | c1980a414c879dd1bc1d8904fd43426326741385 (patch) | |
tree | 2513d6ed5ee49d351652e0f989d83cb8298e72fa /test | |
parent | db2a70b1f5146d6374af57f4bea66ab95b202e77 (diff) |
Added and passed tests for dual variables
Diffstat (limited to 'test')
-rw-r--r-- | test/test_emd.py | 28 |
1 files changed, 19 insertions, 9 deletions
diff --git a/test/test_emd.py b/test/test_emd.py index 4757cd1..3bf6fa2 100644 --- a/test/test_emd.py +++ b/test/test_emd.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- import numpy as np -import pylab as pl import ot from ot.datasets import get_1D_gauss as gauss @@ -16,8 +15,6 @@ m=6000 # nb bins mean1 = 1000 mean2 = 1100 -tol = 1e-6 - # bin positions x=np.arange(n,dtype=np.float64) y=np.arange(m,dtype=np.float64) @@ -38,10 +35,11 @@ print('Computing {} EMD '.format(1)) # emd loss 1 proc ot.tic() -G = ot.emd(a,b,M) +G, alpha, beta = ot.emd(a,b,M, dual_variables=True) ot.toc('1 proc : {} s') cost1 = (G * M).sum() +cost_dual = np.vdot(a, alpha) + np.vdot(b, beta) # emd loss 1 proc ot.tic() @@ -49,11 +47,23 @@ cost_emd2 = ot.emd2(a,b,M) ot.toc('1 proc : {} s') ot.tic() -G = ot.emd(b, a, np.ascontiguousarray(M.T)) +G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) ot.toc('1 proc : {} s') -cost2 = (G * M.T).sum() +cost2 = (G2 * M.T).sum() + +M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1) + +# Check that both cost computations are equivalent +np.testing.assert_almost_equal(cost1, cost_emd2) +# Check that dual and primal cost are equal +np.testing.assert_almost_equal(cost1, cost_dual) +# Check symmetry +np.testing.assert_almost_equal(cost1, cost2) +# Check with closed-form solution for gaussians +np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2)) + +[ind1, ind2] = np.nonzero(G) -assert np.abs(cost1-cost_emd2)/np.abs(cost1) < tol -assert np.abs(cost1-cost2)/np.abs(cost1) < tol -assert np.abs(cost1-np.abs(mean1-mean2))/np.abs(cost1) < tol +# Check that reduced cost is zero on transport arcs +np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size))
\ No newline at end of file |