summaryrefslogtreecommitdiff
path: root/ot/lp/cvx.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-11 16:56:47 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-11 16:56:47 +0200
commit060d9046b291c76244deab2d78ee8356a294e91f (patch)
tree90f775960e4e07c47acc41d1fb5cace61606e1cb /ot/lp/cvx.py
parentbe8817730c7996052e84d21ba08cf60f59020935 (diff)
add cvx barycenter solver
Diffstat (limited to 'ot/lp/cvx.py')
-rw-r--r--ot/lp/cvx.py138
1 files changed, 138 insertions, 0 deletions
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
new file mode 100644
index 0000000..4d08916
--- /dev/null
+++ b/ot/lp/cvx.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+"""
+LP solvers for optimal transport using cvxopt
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import scipy as sp
+import scipy.sparse as sps
+
+try:
+ import cvxopt
+ from cvxopt import solvers, matrix, sparse, spmatrix
+except ImportError:
+ cvxopt=False
+
+def scipy_sparse_to_spmatrix(A):
+ """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
+ coo = A.tocoo()
+ SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
+ return SP
+
+def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-point'):
+ """Compute the entropic regularized wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem [16]:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+
+ The linear program is solved using the default cvxopt solver if installed.
+ If cvxopt is not installed it uses the lp solver from scipy.optimize.
+
+ Parameters
+ ----------
+ A : np.ndarray (d,n)
+ n training distributions of size d
+ M : np.ndarray (d,d)
+ loss matrix for OT
+ reg : float
+ Regularization term >0
+ weights : np.ndarray (n,)
+ Weights of each histogram i_i on the simplex
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ solver : string, optional
+ the solver used, default 'interior-point' use the lp solver from
+ scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
+
+ Returns
+ -------
+ a : (d,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
+
+
+
+ """
+
+ if weights is None:
+ weights = np.ones(A.shape[1]) / A.shape[1]
+ else:
+ assert(len(weights) == A.shape[1])
+
+ n_distributions=A.shape[1]
+ n=A.shape[0]
+
+ n2=n*n
+ c=np.zeros((0))
+ b_eq1=np.zeros((0))
+ for i in range(n_distributions):
+ c=np.concatenate((c,M.ravel()*weights[i]))
+ b_eq1=np.concatenate((b_eq1,A[:,i]))
+ c=np.concatenate((c,np.zeros(n)))
+
+ lst_idiag1=[sps.kron(sps.eye(n),np.ones((1,n))) for i in range(n_distributions)]
+ # row constraints
+ A_eq1=sps.hstack((sps.block_diag(lst_idiag1),sps.coo_matrix((n_distributions*n,n))))
+
+ # columns constraints
+ lst_idiag2=[]
+ lst_eye=[]
+ for i in range(n_distributions):
+ if i==0:
+ lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n)))
+ lst_eye.append(-sps.eye(n))
+ else:
+ lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n-1,n)))
+ lst_eye.append(-sps.eye(n-1,n))
+
+ A_eq2=sps.hstack((sps.block_diag(lst_idiag2),sps.vstack(lst_eye)))
+ b_eq2=np.zeros((A_eq2.shape[0]))
+
+ # full problem
+ A_eq=sps.vstack((A_eq1,A_eq2))
+ b_eq=np.concatenate((b_eq1,b_eq2))
+
+ if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point
+
+ if solver is None:
+ solver='interior-point'
+
+ options={'sparse':True,'disp': verbose}
+ sol=sp.optimize.linprog(c,A_eq=A_eq,b_eq=b_eq,method=solver,options=options)
+ x=sol.x
+ b=x[-n:]
+
+ else:
+
+ h=np.zeros((n_distributions*n2+n))
+ G=-sps.eye(n_distributions*n2+n)
+
+ sol=solvers.lp(matrix(c),scipy_sparse_to_spmatrix(G),matrix(h),A=scipy_sparse_to_spmatrix(A_eq),b=matrix(b_eq),solver=solver)
+
+ x=np.array(sol['x'])
+ b=x[-n:].ravel()
+
+ if log:
+ return b, sol
+ else:
+ return b \ No newline at end of file