From 3aee908ad42d65897f1916de6eab84921ac94a10 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 11 May 2018 16:58:06 +0200 Subject: pep8 --- ot/bregman.py | 2 +- ot/lp/cvx.py | 104 ++++++++++++++++++++++++++++++---------------------------- 2 files changed, 54 insertions(+), 52 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 9c84aed..e788ef5 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -845,7 +845,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, reg : float Regularization term >0 weights : np.ndarray (n,) - Weights of each histogram i_i on the simplex + Weights of each histogram i_i on the simplex numItermax : int, optional Max number of iterations stopThr : float, optional diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 4d08916..93097d1 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -15,7 +15,8 @@ try: import cvxopt from cvxopt import solvers, matrix, sparse, spmatrix except ImportError: - cvxopt=False + cvxopt = False + def scipy_sparse_to_spmatrix(A): """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix""" @@ -23,7 +24,8 @@ def scipy_sparse_to_spmatrix(A): SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape) return SP -def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-point'): + +def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): """Compute the entropic regularized wasserstein barycenter of distributions A The function solves the following optimization problem [16]: @@ -36,7 +38,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - The linear program is solved using the default cvxopt solver if installed. + The linear program is solved using the default cvxopt solver if installed. If cvxopt is not installed it uses the lp solver from scipy.optimize. Parameters @@ -48,13 +50,13 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi reg : float Regularization term >0 weights : np.ndarray (n,) - Weights of each histogram i_i on the simplex + Weights of each histogram i_i on the simplex verbose : bool, optional Print information along iterations log : bool, optional record log if True solver : string, optional - the solver used, default 'interior-point' use the lp solver from + the solver used, default 'interior-point' use the lp solver from scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt. Returns @@ -78,61 +80,61 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi weights = np.ones(A.shape[1]) / A.shape[1] else: assert(len(weights) == A.shape[1]) - - n_distributions=A.shape[1] - n=A.shape[0] - - n2=n*n - c=np.zeros((0)) - b_eq1=np.zeros((0)) + + n_distributions = A.shape[1] + n = A.shape[0] + + n2 = n * n + c = np.zeros((0)) + b_eq1 = np.zeros((0)) for i in range(n_distributions): - c=np.concatenate((c,M.ravel()*weights[i])) - b_eq1=np.concatenate((b_eq1,A[:,i])) - c=np.concatenate((c,np.zeros(n))) - - lst_idiag1=[sps.kron(sps.eye(n),np.ones((1,n))) for i in range(n_distributions)] + c = np.concatenate((c, M.ravel() * weights[i])) + b_eq1 = np.concatenate((b_eq1, A[:, i])) + c = np.concatenate((c, np.zeros(n))) + + lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] # row constraints - A_eq1=sps.hstack((sps.block_diag(lst_idiag1),sps.coo_matrix((n_distributions*n,n)))) - + A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))) + # columns constraints - lst_idiag2=[] - lst_eye=[] + lst_idiag2 = [] + lst_eye = [] for i in range(n_distributions): - if i==0: - lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n))) + if i == 0: + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n))) lst_eye.append(-sps.eye(n)) else: - lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n-1,n))) - lst_eye.append(-sps.eye(n-1,n)) - - A_eq2=sps.hstack((sps.block_diag(lst_idiag2),sps.vstack(lst_eye))) - b_eq2=np.zeros((A_eq2.shape[0])) - + lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n))) + lst_eye.append(-sps.eye(n - 1, n)) + + A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye))) + b_eq2 = np.zeros((A_eq2.shape[0])) + # full problem - A_eq=sps.vstack((A_eq1,A_eq2)) - b_eq=np.concatenate((b_eq1,b_eq2)) - - if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point - + A_eq = sps.vstack((A_eq1, A_eq2)) + b_eq = np.concatenate((b_eq1, b_eq2)) + + if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point + if solver is None: - solver='interior-point' - - options={'sparse':True,'disp': verbose} - sol=sp.optimize.linprog(c,A_eq=A_eq,b_eq=b_eq,method=solver,options=options) - x=sol.x - b=x[-n:] - + solver = 'interior-point' + + options = {'sparse': True, 'disp': verbose} + sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options) + x = sol.x + b = x[-n:] + else: - - h=np.zeros((n_distributions*n2+n)) - G=-sps.eye(n_distributions*n2+n) - - sol=solvers.lp(matrix(c),scipy_sparse_to_spmatrix(G),matrix(h),A=scipy_sparse_to_spmatrix(A_eq),b=matrix(b_eq),solver=solver) - - x=np.array(sol['x']) - b=x[-n:].ravel() - + + h = np.zeros((n_distributions * n2 + n)) + G = -sps.eye(n_distributions * n2 + n) + + sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), solver=solver) + + x = np.array(sol['x']) + b = x[-n:].ravel() + if log: return b, sol else: - return b \ No newline at end of file + return b -- cgit v1.2.3