diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-05-11 17:24:09 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-05-11 17:24:09 +0200 |
commit | fdb2f3af19d04872bafa0d9ec5563732e1d6209b (patch) | |
tree | bc9d94d0d83126e68e633ce3030f801007426fe5 | |
parent | 36f4f7ed2116841d7fe9514ee250bbf16e77b72d (diff) |
add test for barycenter
-rw-r--r-- | ot/lp/cvx.py | 12 | ||||
-rw-r--r-- | test/test_gpu.py | 2 | ||||
-rw-r--r-- | test/test_ot.py | 35 |
3 files changed, 42 insertions, 7 deletions
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index c62da6a..fe9ac76 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -39,7 +39,9 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` The linear program is solved using the interior point solver from scipy.optimize. - If cvxopt solver if installed it can use cvxopt. + If cvxopt solver if installed it can use cvxopt + + Note that this problem do not scale well (both in memory and computational time). Parameters ---------- @@ -114,14 +116,14 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po A_eq = sps.vstack((A_eq1, A_eq2)) b_eq = np.concatenate((b_eq1, b_eq2)) - if not cvxopt or solver in ['interior-point']: + if not cvxopt or solver in ['interior-point']: # cvxopt not installed or interior point if solver is None: solver = 'interior-point' options = {'sparse': True, 'disp': verbose} - sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, + sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options) x = sol.x b = x[-n:] @@ -131,8 +133,8 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po h = np.zeros((n_distributions * n2 + n)) G = -sps.eye(n_distributions * n2 + n) - sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), - A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), + sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), solver=solver) x = np.array(sol['x']) diff --git a/test/test_gpu.py b/test/test_gpu.py index 615c2a7..1e97c45 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -76,4 +76,4 @@ def test_gpu_sinkhorn_lpl1(): time3 - time2)) describe_res(G2) - np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(G1, G2, rtol=1e-3, atol=1e-3) diff --git a/test/test_ot.py b/test/test_ot.py index ea6d9dc..bf23e8c 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -10,7 +10,7 @@ import numpy as np import ot from ot.datasets import get_1D_gauss as gauss - +import pytest def test_doctest(): import doctest @@ -117,6 +117,39 @@ def test_emd2_multi(): np.testing.assert_allclose(emd1, emdn) +def test_lp_barycenter(): + + a1 = np.array([1.0, 0, 0])[:, None] + a2 = np.array([0, 0, 1.0])[:, None] + + A = np.hstack((a1, a2)) + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) + + # obvious barycenter between two diracs + bary0 = np.array([0, 1.0, 0]) + + bary = ot.lp.barycenter(A, M, [.5, .5]) + + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(bary.sum(), 1) + +@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") +def test_lp_barycenter_cvxopt(): + + a1 = np.array([1.0, 0, 0])[:, None] + a2 = np.array([0, 0, 1.0])[:, None] + + A = np.hstack((a1, a2)) + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) + + # obvious barycenter between two diracs + bary0 = np.array([0, 1.0, 0]) + + bary = ot.lp.barycenter(A, M, [.5, .5],solver=None) + + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(bary.sum(), 1) + def test_warnings(): n = 100 # nb bins m = 100 # nb bins |