summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Rémi Flamary <remi.flamary@gmail.com>, 2018-05-29 16:16:41 +0200
committer: GitHub <noreply@github.com>, 2018-05-29 16:16:41 +0200
commit: 90efa5a8b189214d1aeb81920b2bb04ce0c261ca (patch)
tree: 62e2f1a3cca2f4885e8c0e2a0b135a5f574d6a8c
parent: ec79b791f4f4a62f7c04b7bbf14fe2f5dcbb4c75 (diff)
parent: 54f0b47e55c966d5492e4ce19ec4e704ef3278d6 (diff)
Merge pull request #47 from rflamary/bary
LP Wasserstein barycenter with scipy linear solver and/or cvxopt
-rw-r--r-- Makefile 4
-rw-r--r-- README.md 5
-rw-r--r-- examples/plot_barycenter_lp_vs_entropic.py 292
-rw-r--r-- ot/bregman.py 4
-rw-r--r-- ot/lp/__init__.py 3
-rw-r--r-- ot/lp/cvx.py 146
-rw-r--r-- requirements.txt 2
-rw-r--r-- test/test_gpu.py 2
-rw-r--r-- test/test_ot.py 36
9 files changed, 488 insertions, 6 deletions
diff --git a/Makefile b/Makefile
index 0de3fe9..1abc6e9 100644
--- a/Makefile
+++ b/Makefile
@@ -58,9 +58,9 @@ notebook :
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
autopep8 :
- autopep8 -ir test ot examples
+ autopep8 -ir test ot examples --jobs -1
aautopep8 :
- autopep8 -air test ot examples
+ autopep8 -air test ot examples --jobs -1
FORCE :
diff --git a/README.md b/README.md
index 6b7cff0..466c09c 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,8 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:
* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
-* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat).
+* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
+* Non regularized Wasserstein barycenters [16] with LP solver.
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -210,3 +211,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43.
[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) .
+
+[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924.
diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/plot_barycenter_lp_vs_entropic.py
new file mode 100644
index 0000000..6936bbb
--- /dev/null
+++ b/examples/plot_barycenter_lp_vs_entropic.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+"""
+=================================================================================
+1D Wasserstein barycenter comparison between exact LP and entropic regularization
+=================================================================================
+
+This example illustrates the computation of a regularized Wasserstein barycenter
+as proposed in [3] and of exact LP barycenters using a standard LP solver.
+
+It approximately reproduces Figures 3.1 and 3.2 from the following paper:
+Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational
+Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection # noqa
+
+#import ot.lp.cvx as cvx
+
+#
+# Generate data
+# -------------
+
+#%% parameters
+
+problems = []
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+# (two 1D Gaussians with different means and standard deviations)
+a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+#
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+#
+# Barycenter computation
+# ----------------------
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+#%% parameters
+
+a1 = 1.0 * (x > 10) * (x < 50)
+a2 = 1.0 * (x > 60) * (x < 80)
+
+a1 /= a1.sum()
+a2 /= a2.sum()
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+#
+# Barycenter computation
+# ----------------------
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+#%% parameters
+
+a1 = np.zeros(n)
+a2 = np.zeros(n)
+
+a1[10] = .25
+a1[20] = .5
+a1[30] = .25
+a2[80] = 1
+
+
+a1 /= a1.sum()
+a2 /= a2.sum()
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+#
+# Barycenter computation
+# ----------------------
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+
+#
+# Final figure
+# ------------
+#
+
+#%% plot
+
+nbm = len(problems)
+nbm2 = (nbm // 2)
+
+
+pl.figure(2, (20, 6))
+pl.clf()
+
+for i in range(nbm):
+
+ A = problems[i][0]
+ bary_l2 = problems[i][1][0]
+ bary_wass = problems[i][1][1]
+ bary_wass2 = problems[i][1][2]
+
+ pl.subplot(2, nbm, 1 + i)
+ for j in range(n_distributions):
+ pl.plot(x, A[:, j])
+ if i == nbm2:
+ pl.title('Distributions')
+ pl.xticks(())
+ pl.yticks(())
+
+ pl.subplot(2, nbm, 1 + i + nbm)
+
+ pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
+ pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+ pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+ if i == nbm - 1:
+ pl.legend()
+ if i == nbm2:
+ pl.title('Barycenters')
+
+ pl.xticks(())
+ pl.yticks(())
diff --git a/ot/bregman.py b/ot/bregman.py
index 07b8660..b017c1a 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -839,11 +839,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
A : np.ndarray (d,n)
- n training distributions of size d
+ n training distributions a_i of size d
M : np.ndarray (d,d)
loss matrix for OT
reg : float
Regularization term >0
+ weights : np.ndarray (n,)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 6371feb..5dda82a 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -11,9 +11,12 @@ import multiprocessing
import numpy as np
+from .import cvx
+
# import compiled emd
from .emd_wrap import emd_c, check_result
from ..utils import parmap
+from .cvx import barycenter
def emd(a, b, M, numItermax=100000, log=False):
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
new file mode 100644
index 0000000..c8c75bc
--- /dev/null
+++ b/ot/lp/cvx.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+"""
+LP solvers for optimal transport using scipy.optimize and, when installed, cvxopt
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import scipy as sp
+import scipy.sparse as sps
+
+try:
+ import cvxopt
+ from cvxopt import solvers, matrix, spmatrix
+except ImportError:
+ cvxopt = False
+
+
+def scipy_sparse_to_spmatrix(A):
+ """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
+ coo = A.tocoo()
+ SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
+ return SP
+
+
+def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
+ """Compute the exact (non-regularized) Wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem [16]:
+
+ .. math::
+ \mathbf{a} = \arg\min_\mathbf{a} \sum_i w_i W_{1}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance with loss matrix :math:`\mathbf{M}` (see ot.lp.emd)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`, each weighted by :math:`w_i` (parameter ``weights``)
+
+ The linear program is solved using the interior point solver from scipy.optimize.
+ If cvxopt is installed, its solvers can be used instead.
+
+ Note that this problem does not scale well (both in memory and computational time).
+
+ Parameters
+ ----------
+ A : np.ndarray (d,n)
+ n training distributions a_i of size d
+ M : np.ndarray (d,d)
+ loss matrix for OT
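+ weights : np.ndarray (n,), optional
+ Weights of each histogram a_i on the simplex (barycentric coordinates).
+ Uniform weights (1/n each) are used when None. Unlike ot.bregman.barycenter,
+ this solver takes no regularization parameter.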
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ solver : string, optional
+ the solver used; the default 'interior-point' uses the LP solver from
+ scipy.optimize. None, 'glpk' or 'mosek' use the solvers from cvxopt.
+
+ Returns
+ -------
+ a : (d,) ndarray
+ Wasserstein barycenter
+ log : dict
+ solver result (as returned by scipy.optimize.linprog or cvxopt), returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
+
+
+
+ """
+
+ if weights is None:
+ weights = np.ones(A.shape[1]) / A.shape[1]
+ else:
+ assert(len(weights) == A.shape[1])
+
+ n_distributions = A.shape[1]
+ n = A.shape[0]
+
+ n2 = n * n
+ c = np.zeros((0))
+ b_eq1 = np.zeros((0))
+ for i in range(n_distributions):
+ c = np.concatenate((c, M.ravel() * weights[i]))
+ b_eq1 = np.concatenate((b_eq1, A[:, i]))
+ c = np.concatenate((c, np.zeros(n)))
+
+ lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)]
+ # row constraints
+ A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))))
+
+ # columns constraints
+ lst_idiag2 = []
+ lst_eye = []
+ for i in range(n_distributions):
+ if i == 0:
+ lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n)))
+ lst_eye.append(-sps.eye(n))
+ else:
+ lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n)))
+ lst_eye.append(-sps.eye(n - 1, n))
+
+ A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye)))
+ b_eq2 = np.zeros((A_eq2.shape[0]))
+
+ # full problem
+ A_eq = sps.vstack((A_eq1, A_eq2))
+ b_eq = np.concatenate((b_eq1, b_eq2))
+
+ if not cvxopt or solver in ['interior-point']:
+ # cvxopt not installed or interior point
+
+ if solver is None:
+ solver = 'interior-point'
+
+ options = {'sparse': True, 'disp': verbose}
+ sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
+ options=options)
+ x = sol.x
+ b = x[-n:]
+
+ else:
+
+ h = np.zeros((n_distributions * n2 + n))
+ G = -sps.eye(n_distributions * n2 + n)
+
+ sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h),
+ A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq),
+ solver=solver)
+
+ x = np.array(sol['x'])
+ b = x[-n:].ravel()
+
+ if log:
+ return b, sol
+ else:
+ return b
diff --git a/requirements.txt b/requirements.txt
index 37d62cc..97d165b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
numpy
-scipy
+scipy>=1.0
cython
matplotlib
sphinx-gallery
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 615c2a7..1e97c45 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -76,4 +76,4 @@ def test_gpu_sinkhorn_lpl1():
time3 - time2))
describe_res(G2)
- np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5)
+ np.testing.assert_allclose(G1, G2, rtol=1e-3, atol=1e-3)
diff --git a/test/test_ot.py b/test/test_ot.py
index ea6d9dc..cc25bf4 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -10,6 +10,7 @@ import numpy as np
import ot
from ot.datasets import get_1D_gauss as gauss
+import pytest
def test_doctest():
@@ -117,6 +118,41 @@ def test_emd2_multi():
np.testing.assert_allclose(emd1, emdn)
+def test_lp_barycenter():
+
+ a1 = np.array([1.0, 0, 0])[:, None]
+ a2 = np.array([0, 0, 1.0])[:, None]
+
+ A = np.hstack((a1, a2))
+ M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
+
+ # obvious barycenter between two diracs
+ bary0 = np.array([0, 1.0, 0])
+
+ bary = ot.lp.barycenter(A, M, [.5, .5])
+
+ np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
+ np.testing.assert_allclose(bary.sum(), 1)
+
+
+@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
+def test_lp_barycenter_cvxopt():
+
+ a1 = np.array([1.0, 0, 0])[:, None]
+ a2 = np.array([0, 0, 1.0])[:, None]
+
+ A = np.hstack((a1, a2))
+ M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
+
+ # obvious barycenter between two diracs
+ bary0 = np.array([0, 1.0, 0])
+
+ bary = ot.lp.barycenter(A, M, [.5, .5], solver=None)
+
+ np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
+ np.testing.assert_allclose(bary.sum(), 1)
+
+
def test_warnings():
n = 100 # nb bins
m = 100 # nb bins