summaryrefslogtreecommitdiff
path: root/ot/lp/cvx.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-11 16:58:06 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-11 16:58:06 +0200
commit3aee908ad42d65897f1916de6eab84921ac94a10 (patch)
treebea38fb348e55118f36d7c4457310922a34da5f7 /ot/lp/cvx.py
parent060d9046b291c76244deab2d78ee8356a294e91f (diff)
pep8
Diffstat (limited to 'ot/lp/cvx.py')
-rw-r--r--ot/lp/cvx.py104
1 files changed, 53 insertions, 51 deletions
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 4d08916..93097d1 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -15,7 +15,8 @@ try:
import cvxopt
from cvxopt import solvers, matrix, sparse, spmatrix
except ImportError:
- cvxopt=False
+ cvxopt = False
+
def scipy_sparse_to_spmatrix(A):
"""Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
@@ -23,7 +24,8 @@ def scipy_sparse_to_spmatrix(A):
SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
return SP
-def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-point'):
+
+def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem [16]:
@@ -36,7 +38,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
- :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- The linear program is solved using the default cvxopt solver if installed.
+ The linear program is solved using the default cvxopt solver if installed.
If cvxopt is not installed it uses the lp solver from scipy.optimize.
Parameters
@@ -48,13 +50,13 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
reg : float
Regularization term >0
weights : np.ndarray (n,)
- Weights of each histogram i_i on the simplex
+ Weights of each histogram i_i on the simplex
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
solver : string, optional
- the solver used, default 'interior-point' use the lp solver from
+ the solver used, default 'interior-point' use the lp solver from
scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
Returns
@@ -78,61 +80,61 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
weights = np.ones(A.shape[1]) / A.shape[1]
else:
assert(len(weights) == A.shape[1])
-
- n_distributions=A.shape[1]
- n=A.shape[0]
-
- n2=n*n
- c=np.zeros((0))
- b_eq1=np.zeros((0))
+
+ n_distributions = A.shape[1]
+ n = A.shape[0]
+
+ n2 = n * n
+ c = np.zeros((0))
+ b_eq1 = np.zeros((0))
for i in range(n_distributions):
- c=np.concatenate((c,M.ravel()*weights[i]))
- b_eq1=np.concatenate((b_eq1,A[:,i]))
- c=np.concatenate((c,np.zeros(n)))
-
- lst_idiag1=[sps.kron(sps.eye(n),np.ones((1,n))) for i in range(n_distributions)]
+ c = np.concatenate((c, M.ravel() * weights[i]))
+ b_eq1 = np.concatenate((b_eq1, A[:, i]))
+ c = np.concatenate((c, np.zeros(n)))
+
+ lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)]
# row constraints
- A_eq1=sps.hstack((sps.block_diag(lst_idiag1),sps.coo_matrix((n_distributions*n,n))))
-
+ A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))))
+
# columns constraints
- lst_idiag2=[]
- lst_eye=[]
+ lst_idiag2 = []
+ lst_eye = []
for i in range(n_distributions):
- if i==0:
- lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n)))
+ if i == 0:
+ lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n)))
lst_eye.append(-sps.eye(n))
else:
- lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n-1,n)))
- lst_eye.append(-sps.eye(n-1,n))
-
- A_eq2=sps.hstack((sps.block_diag(lst_idiag2),sps.vstack(lst_eye)))
- b_eq2=np.zeros((A_eq2.shape[0]))
-
+ lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n)))
+ lst_eye.append(-sps.eye(n - 1, n))
+
+ A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye)))
+ b_eq2 = np.zeros((A_eq2.shape[0]))
+
# full problem
- A_eq=sps.vstack((A_eq1,A_eq2))
- b_eq=np.concatenate((b_eq1,b_eq2))
-
- if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point
-
+ A_eq = sps.vstack((A_eq1, A_eq2))
+ b_eq = np.concatenate((b_eq1, b_eq2))
+
+ if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point
+
if solver is None:
- solver='interior-point'
-
- options={'sparse':True,'disp': verbose}
- sol=sp.optimize.linprog(c,A_eq=A_eq,b_eq=b_eq,method=solver,options=options)
- x=sol.x
- b=x[-n:]
-
+ solver = 'interior-point'
+
+ options = {'sparse': True, 'disp': verbose}
+ sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options)
+ x = sol.x
+ b = x[-n:]
+
else:
-
- h=np.zeros((n_distributions*n2+n))
- G=-sps.eye(n_distributions*n2+n)
-
- sol=solvers.lp(matrix(c),scipy_sparse_to_spmatrix(G),matrix(h),A=scipy_sparse_to_spmatrix(A_eq),b=matrix(b_eq),solver=solver)
-
- x=np.array(sol['x'])
- b=x[-n:].ravel()
-
+
+ h = np.zeros((n_distributions * n2 + n))
+ G = -sps.eye(n_distributions * n2 + n)
+
+ sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), solver=solver)
+
+ x = np.array(sol['x'])
+ b = x[-n:].ravel()
+
if log:
return b, sol
else:
- return b \ No newline at end of file
+ return b