diff options
-rw-r--r-- | Makefile | 4 | ||||
-rw-r--r-- | README.md | 5 | ||||
-rw-r--r-- | examples/plot_barycenter_lp_vs_entropic.py | 292 | ||||
-rw-r--r-- | ot/bregman.py | 4 | ||||
-rw-r--r-- | ot/lp/__init__.py | 3 | ||||
-rw-r--r-- | ot/lp/cvx.py | 146 | ||||
-rw-r--r-- | requirements.txt | 2 | ||||
-rw-r--r-- | test/test_gpu.py | 2 | ||||
-rw-r--r-- | test/test_ot.py | 36 |
9 files changed, 488 insertions, 6 deletions
@@ -58,9 +58,9 @@ notebook : ipython notebook --matplotlib=inline --notebook-dir=notebooks/ autopep8 : - autopep8 -ir test ot examples + autopep8 -ir test ot examples --jobs -1 aautopep8 : - autopep8 -air test ot examples + autopep8 -air test ot examples --jobs -1 FORCE : @@ -14,7 +14,8 @@ This open source Python library provide several solvers for optimization problem It provides the following solvers: * OT Network Flow solver for the linear program/ Earth Movers Distance [1]. -* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat). +* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat). +* Non regularized Wasserstein barycenters [16] with LP solver. * Bregman projections for Wasserstein barycenter [3] and unmixing [4]. * Optimal transport for domain adaptation with group lasso regularization [5] * Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. @@ -210,3 +211,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43. [15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) . + +[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/plot_barycenter_lp_vs_entropic.py new file mode 100644 index 0000000..6936bbb --- /dev/null +++ b/examples/plot_barycenter_lp_vs_entropic.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +""" +================================================================================= +1D Wasserstein barycenter comparison between exact LP and entropic regularization +================================================================================= + +This example illustrates the computation of regularized Wasserstein Barycenter +as proposed in [3] and exact LP barycenters using standard LP solver. + +It reproduces approximately Figure 3.1 and 3.2 from the following paper: +Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational +Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343. + +[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). +Iterative Bregman projections for regularized transportation problems +SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection # noqa + +#import ot.lp.cvx as cvx + +# +# Generate data +# ------------- + +#%% parameters + +problems = [] + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +# Gaussian distributions +a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.get_1D_gauss(n, m=60, s=8) + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +# +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +# +# Barycenter computation +# ---------------------- + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +#%% parameters + +a1 = 1.0 * (x > 10) * (x < 50) +a2 = 1.0 * (x > 60) * (x < 80) + +a1 /= a1.sum() +a2 /= a2.sum() + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +# +# Barycenter computation +# ---------------------- + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + + +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +#%% parameters + +a1 = np.zeros(n) +a2 = np.zeros(n) + +a1[10] = .25 +a1[20] = .5 +a1[30] = .25 +a2[80] = 1 + + +a1 /= a1.sum() +a2 /= a2.sum() + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +# +# Barycenter computation +# ---------------------- + +#%% barycenter computation + +alpha = 0.5 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +ot.tic() +bary_wass = ot.bregman.barycenter(A, M, reg, weights) +ot.toc() + + +ot.tic() +bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True) +ot.toc() + + +problems.append([A, [bary_l2, bary_wass, bary_wass2]]) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') +pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + + +# +# Final figure +# ------------ +# + +#%% plot + +nbm = len(problems) +nbm2 = (nbm // 2) + + +pl.figure(2, (20, 6)) +pl.clf() + +for i in range(nbm): + + A = problems[i][0] + bary_l2 = problems[i][1][0] + bary_wass = problems[i][1][1] + bary_wass2 = problems[i][1][2] + + pl.subplot(2, nbm, 1 + i) + for j in range(n_distributions): + pl.plot(x, A[:, j]) + if i == nbm2: + pl.title('Distributions') + pl.xticks(()) + pl.yticks(()) + + pl.subplot(2, nbm, 1 + i + nbm) + + pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)') + pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') + pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') + if i == nbm - 1: + pl.legend() + if i == nbm2: + pl.title('Barycenters') + + pl.xticks(()) + pl.yticks(()) diff --git a/ot/bregman.py b/ot/bregman.py index 07b8660..b017c1a 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -839,11 +839,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, Parameters ---------- A : np.ndarray (d,n) - n training distributions of size d + n training distributions a_i of size d M : np.ndarray (d,d) loss matrix for OT reg : float Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 6371feb..5dda82a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,9 +11,12 @@ import multiprocessing import numpy as np +from .import cvx + # import compiled emd from .emd_wrap import emd_c, check_result from ..utils import parmap +from .cvx import barycenter def emd(a, b, M, numItermax=100000, log=False): diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py new file mode 100644 index 0000000..c8c75bc --- /dev/null +++ b/ot/lp/cvx.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +""" +LP solvers for optimal transport using cvxopt +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import scipy as sp +import scipy.sparse as sps + +try: + import cvxopt + from cvxopt import solvers, matrix, spmatrix +except ImportError: + cvxopt = False + + +def scipy_sparse_to_spmatrix(A): + """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" + coo = A.tocoo() + SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) + return SP + + +def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): + """Compute the entropic regularized wasserstein barycenter of distributions A + + The function solves the following optimization problem [16]: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + + The linear program is solved using the interior point solver from scipy.optimize. + If cvxopt solver if installed it can use cvxopt + + Note that this problem do not scale well (both in memory and computational time). + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions a_i of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coodinates) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + solver : string, optional + the solver used, default 'interior-point' use the lp solver from + scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. + + Returns + ------- + a : (d,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. + + + + """ + + if weights is None: + weights = np.ones(A.shape[1]) / A.shape[1] + else: + assert(len(weights) == A.shape[1]) + + n_distributions = A.shape[1] + n = A.shape[0] + + n2 = n * n + c = np.zeros((0)) + b_eq1 = np.zeros((0)) + for i in range(n_distributions): + c = np.concatenate((c, M.ravel() * weights[i])) + b_eq1 = np.concatenate((b_eq1, A[:, i])) + c = np.concatenate((c, np.zeros(n))) + + lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] + # row constraints + A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))) + + # columns constraints + lst_idiag2 = [] + lst_eye = [] + for i in range(n_distributions): + if i == 0: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) + lst_eye.append(-sps.eye(n)) + else: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) + lst_eye.append(-sps.eye(n - 1, n)) + + A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) + b_eq2 = np.zeros((A_eq2.shape[0])) + + # full problem + A_eq = sps.vstack((A_eq1, A_eq2)) + b_eq = np.concatenate((b_eq1, b_eq2)) + + if not cvxopt or solver in ['interior-point']: + # cvxopt not installed or interior point + + if solver is None: + solver = 'interior-point' + + options = {'sparse': True, 'disp': verbose} + sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, + options=options) + x = sol.x + b = x[-n:] + + else: + + h = np.zeros((n_distributions * n2 + n)) + G = -sps.eye(n_distributions * n2 + n) + + sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), + solver=solver) + + x = np.array(sol['x']) + b = x[-n:].ravel() + + if log: + return b, sol + else: + return b diff --git a/requirements.txt b/requirements.txt index 37d62cc..97d165b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -scipy +scipy>=1.0 cython matplotlib sphinx-gallery diff --git a/test/test_gpu.py b/test/test_gpu.py index 615c2a7..1e97c45 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -76,4 +76,4 @@ def test_gpu_sinkhorn_lpl1(): time3 - time2)) describe_res(G2) - np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(G1, G2, rtol=1e-3, atol=1e-3) diff --git a/test/test_ot.py b/test/test_ot.py index ea6d9dc..cc25bf4 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -10,6 +10,7 @@ import numpy as np import ot from ot.datasets import get_1D_gauss as gauss +import pytest def test_doctest(): @@ -117,6 +118,41 @@ def test_emd2_multi(): np.testing.assert_allclose(emd1, emdn) +def test_lp_barycenter(): + + a1 = np.array([1.0, 0, 0])[:, None] + a2 = np.array([0, 0, 1.0])[:, None] + + A = np.hstack((a1, a2)) + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) + + # obvious barycenter between two diracs + bary0 = np.array([0, 1.0, 0]) + + bary = ot.lp.barycenter(A, M, [.5, .5]) + + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(bary.sum(), 1) + + +@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") +def test_lp_barycenter_cvxopt(): + + a1 = np.array([1.0, 0, 0])[:, None] + a2 = np.array([0, 0, 1.0])[:, None] + + A = np.hstack((a1, a2)) + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) + + # obvious barycenter between two diracs + bary0 = np.array([0, 1.0, 0]) + + bary = ot.lp.barycenter(A, M, [.5, .5], solver=None) + + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(bary.sum(), 1) + + def test_warnings(): n = 100 # nb bins m = 100 # nb bins |