summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-11 17:24:09 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-11 17:24:09 +0200
commitfdb2f3af19d04872bafa0d9ec5563732e1d6209b (patch)
treebc9d94d0d83126e68e633ce3030f801007426fe5
parent36f4f7ed2116841d7fe9514ee250bbf16e77b72d (diff)
add test for barycenter
-rw-r--r--ot/lp/cvx.py12
-rw-r--r--test/test_gpu.py2
-rw-r--r--test/test_ot.py35
3 files changed, 42 insertions, 7 deletions
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index c62da6a..fe9ac76 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -39,7 +39,9 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
The linear program is solved using the interior point solver from scipy.optimize.
- If cvxopt solver if installed it can use cvxopt.
+ If cvxopt solver if installed it can use cvxopt
+
+ Note that this problem do not scale well (both in memory and computational time).
Parameters
----------
@@ -114,14 +116,14 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
A_eq = sps.vstack((A_eq1, A_eq2))
b_eq = np.concatenate((b_eq1, b_eq2))
- if not cvxopt or solver in ['interior-point']:
+ if not cvxopt or solver in ['interior-point']:
# cvxopt not installed or interior point
if solver is None:
solver = 'interior-point'
options = {'sparse': True, 'disp': verbose}
- sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
+ sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
options=options)
x = sol.x
b = x[-n:]
@@ -131,8 +133,8 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
h = np.zeros((n_distributions * n2 + n))
G = -sps.eye(n_distributions * n2 + n)
- sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h),
- A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq),
+ sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h),
+ A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq),
solver=solver)
x = np.array(sol['x'])
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 615c2a7..1e97c45 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -76,4 +76,4 @@ def test_gpu_sinkhorn_lpl1():
time3 - time2))
describe_res(G2)
- np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5)
+ np.testing.assert_allclose(G1, G2, rtol=1e-3, atol=1e-3)
diff --git a/test/test_ot.py b/test/test_ot.py
index ea6d9dc..bf23e8c 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -10,7 +10,7 @@ import numpy as np
import ot
from ot.datasets import get_1D_gauss as gauss
-
+import pytest
def test_doctest():
import doctest
@@ -117,6 +117,39 @@ def test_emd2_multi():
np.testing.assert_allclose(emd1, emdn)
+def test_lp_barycenter():
+
+ a1 = np.array([1.0, 0, 0])[:, None]
+ a2 = np.array([0, 0, 1.0])[:, None]
+
+ A = np.hstack((a1, a2))
+ M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
+
+ # obvious barycenter between two diracs
+ bary0 = np.array([0, 1.0, 0])
+
+ bary = ot.lp.barycenter(A, M, [.5, .5])
+
+ np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
+ np.testing.assert_allclose(bary.sum(), 1)
+
+@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
+def test_lp_barycenter_cvxopt():
+
+ a1 = np.array([1.0, 0, 0])[:, None]
+ a2 = np.array([0, 0, 1.0])[:, None]
+
+ A = np.hstack((a1, a2))
+ M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
+
+ # obvious barycenter between two diracs
+ bary0 = np.array([0, 1.0, 0])
+
+ bary = ot.lp.barycenter(A, M, [.5, .5],solver=None)
+
+ np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
+ np.testing.assert_allclose(bary.sum(), 1)
+
def test_warnings():
n = 100 # nb bins
m = 100 # nb bins