diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2018-05-29 16:16:41 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-05-29 16:16:41 +0200 |
commit | 90efa5a8b189214d1aeb81920b2bb04ce0c261ca (patch) | |
tree | 62e2f1a3cca2f4885e8c0e2a0b135a5f574d6a8c /ot | |
parent | ec79b791f4f4a62f7c04b7bbf14fe2f5dcbb4c75 (diff) | |
parent | 54f0b47e55c966d5492e4ce19ec4e704ef3278d6 (diff) |
Merge pull request #47 from rflamary/bary
LP Wasserstein barycenter with scipy linear solver and/or cvxopt
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 4 | ||||
-rw-r--r-- | ot/lp/__init__.py | 3 | ||||
-rw-r--r-- | ot/lp/cvx.py | 146 |
3 files changed, 152 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 07b8660..b017c1a 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -839,11 +839,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, Parameters ---------- A : np.ndarray (d,n) - n training distributions of size d + n training distributions a_i of size d M : np.ndarray (d,d) loss matrix for OT reg : float Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 6371feb..5dda82a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,9 +11,12 @@ import multiprocessing import numpy as np +from .import cvx + # import compiled emd from .emd_wrap import emd_c, check_result from ..utils import parmap +from .cvx import barycenter def emd(a, b, M, numItermax=100000, log=False): diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py new file mode 100644 index 0000000..c8c75bc --- /dev/null +++ b/ot/lp/cvx.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +""" +LP solvers for optimal transport using cvxopt +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import scipy as sp +import scipy.sparse as sps + +try: + import cvxopt + from cvxopt import solvers, matrix, spmatrix +except ImportError: + cvxopt = False + + +def scipy_sparse_to_spmatrix(A): + """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" + coo = A.tocoo() + SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) + return SP + + +def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): + """Compute the entropic regularized wasserstein barycenter of distributions A + + The function solves the following optimization problem [16]: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + + The linear program is solved using the interior point solver from scipy.optimize. + If cvxopt solver if installed it can use cvxopt + + Note that this problem do not scale well (both in memory and computational time). + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions a_i of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coodinates) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + solver : string, optional + the solver used, default 'interior-point' use the lp solver from + scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. + + Returns + ------- + a : (d,) ndarray + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. + + + + """ + + if weights is None: + weights = np.ones(A.shape[1]) / A.shape[1] + else: + assert(len(weights) == A.shape[1]) + + n_distributions = A.shape[1] + n = A.shape[0] + + n2 = n * n + c = np.zeros((0)) + b_eq1 = np.zeros((0)) + for i in range(n_distributions): + c = np.concatenate((c, M.ravel() * weights[i])) + b_eq1 = np.concatenate((b_eq1, A[:, i])) + c = np.concatenate((c, np.zeros(n))) + + lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] + # row constraints + A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))) + + # columns constraints + lst_idiag2 = [] + lst_eye = [] + for i in range(n_distributions): + if i == 0: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) + lst_eye.append(-sps.eye(n)) + else: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) + lst_eye.append(-sps.eye(n - 1, n)) + + A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) + b_eq2 = np.zeros((A_eq2.shape[0])) + + # full problem + A_eq = sps.vstack((A_eq1, A_eq2)) + b_eq = np.concatenate((b_eq1, b_eq2)) + + if not cvxopt or solver in ['interior-point']: + # cvxopt not installed or interior point + + if solver is None: + solver = 'interior-point' + + options = {'sparse': True, 'disp': verbose} + sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, + options=options) + x = sol.x + b = x[-n:] + + else: + + h = np.zeros((n_distributions * n2 + n)) + G = -sps.eye(n_distributions * n2 + n) + + sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), + solver=solver) + + x = np.array(sol['x']) + b = x[-n:].ravel() + + if log: + return b, sol + else: + return b |