summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 13:50:41 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 13:50:41 +0900
commit12d9b3ff72e9669ccc0162e82b7a33beb51d3e25 (patch)
tree72a2908e9d0e67f7e8499d7ef9aca1246528a980 /test
parentf8c1c8740f9974dcf4aaf191851d62149dceb91c (diff)
Return dual variables in an optional dictionary
Also removed some code duplication
Diffstat (limited to 'test')
-rw-r--r--test/test_ot.py20
1 files changed, 7 insertions, 13 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 8a19cf6..78f64ab 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -124,27 +124,26 @@ def test_warnings():
# %%
print('Computing {} EMD '.format(1))
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
# Trigger a warning.
print('Computing {} EMD '.format(1))
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True, numItermax=1)
+ G = ot.emd(a, b, M, numItermax=1)
# Verify some things
assert "numItermax" in str(w[-1].message)
assert len(w) == 1
# Trigger a warning.
a[0]=100
print('Computing {} EMD '.format(2))
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ G = ot.emd(a, b, M)
# Verify some things
assert "infeasible" in str(w[-1].message)
assert len(w) == 2
# Trigger a warning.
a[0]=-1
print('Computing {} EMD '.format(2))
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ G = ot.emd(a, b, M)
# Verify some things
assert "infeasible" in str(w[-1].message)
assert len(w) == 3
@@ -176,16 +175,11 @@ def test_dual_variables():
# emd loss 1 proc
ot.tic()
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ G, log = ot.emd(a, b, M, log=True)
ot.toc('1 proc : {} s')
cost1 = (G * M).sum()
- cost_dual = np.vdot(a, alpha) + np.vdot(b, beta)
-
- # emd loss 1 proc
- ot.tic()
- cost_emd2 = ot.emd2(a, b, M)
- ot.toc('1 proc : {} s')
+ cost_dual = np.vdot(a, log['u']) + np.vdot(b, log['v'])
ot.tic()
G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
@@ -194,7 +188,7 @@ def test_dual_variables():
cost2 = (G2 * M.T).sum()
# Check that both cost computations are equivalent
- np.testing.assert_almost_equal(cost1, cost_emd2)
+ np.testing.assert_almost_equal(cost1, log['cost'])
# Check that dual and primal cost are equal
np.testing.assert_almost_equal(cost1, cost_dual)
# Check symmetry
@@ -205,5 +199,5 @@ def test_dual_variables():
[ind1, ind2] = np.nonzero(G)
# Check that reduced cost is zero on transport arcs
- np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2],
+ np.testing.assert_array_almost_equal((M - log['u'].reshape(-1, 1) - log['v'].reshape(1, -1))[ind1, ind2],
np.zeros(ind1.size))