From e585d64dd08e5367350e70f23e81f9fd2d676a6b Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 31 May 2018 11:33:08 +0200 Subject: pep8 --- ot/smooth.py | 67 +++++++++++++++++++++++++++++------------------------------- 1 file changed, 32 insertions(+), 35 deletions(-) (limited to 'ot') diff --git a/ot/smooth.py b/ot/smooth.py index f8bb20a..f4f4306 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -1,4 +1,4 @@ -#Copyright (c) 2018, Mathieu Blondel +#Copyright (c) 2018, Mathieu Blondel #All rights reserved. # #Redistribution and use in source and binary forms, with or without @@ -22,7 +22,7 @@ #OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF #THE POSSIBILITY OF SUCH DAMAGE. -# Author: Mathieu Blondel +# Author: Mathieu Blondel # Remi Flamary """ @@ -39,7 +39,7 @@ from scipy.optimize import minimize def projection_simplex(V, z=1, axis=None): """ Projection of x onto the simplex, scaled by z - + P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2 z: float or array If array, len(z) must be compatible with V @@ -65,19 +65,19 @@ def projection_simplex(V, z=1, axis=None): else: V = V.ravel().reshape(1, -1) return projection_simplex(V, z, axis=1).ravel() - + class Regularization(object): def __init__(self, gamma=1.0): """ - + Parameters ---------- gamma: float Regularization parameter. We recover unregularized OT when gamma -> 0. - + """ self.gamma = gamma @@ -85,12 +85,12 @@ class Regularization(object): """ Compute delta_Omega(X[:, j]) for each X[:, j]. delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y). - + Parameters ---------- X: array, shape = len(a) x len(b) Input array. - + Returns ------- v: array, len(b) @@ -104,12 +104,12 @@ class Regularization(object): """ Compute max_Omega_j(X[:, j]) for each X[:, j]. max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j]. - + Parameters ---------- X: array, shape = len(a) x len(b) Input array. - + Returns ------- v: array, len(b) @@ -122,12 +122,12 @@ class Regularization(object): def Omega(T): """ Compute regularization term. - + Parameters ---------- T: array, shape = len(a) x len(b) Input array. - + Returns ------- value: float @@ -176,7 +176,7 @@ class SquaredL2(Regularization): def dual_obj_grad(alpha, beta, a, b, C, regul): """ Compute objective value and gradients of dual objective. - + Parameters ---------- alpha: array, shape = len(a) @@ -189,7 +189,7 @@ def dual_obj_grad(alpha, beta, a, b, C, regul): Ground cost matrix. regul: Regularization object Should implement a delta_Omega(X) method. - + Returns ------- obj: float @@ -220,7 +220,7 @@ def dual_obj_grad(alpha, beta, a, b, C, regul): def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): """ Solve the "smoothed" dual objective. - + Parameters ---------- a: array, shape = len(a) @@ -236,7 +236,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): Tolerance parameter. max_iter: int Maximum number of iterations. - + Returns ------- alpha: array, shape = len(a) @@ -275,7 +275,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): def semi_dual_obj_grad(alpha, a, b, C, regul): """ Compute objective value and gradient of semi-dual objective. - + Parameters ---------- alpha: array, shape = len(a) @@ -287,7 +287,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): Ground cost matrix. regul: Regularization object Should implement a max_Omega(X) method. - + Returns ------- obj: float @@ -314,7 +314,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): """ Solve the "smoothed" semi-dual objective. - + Parameters ---------- a: array, shape = len(a) @@ -330,7 +330,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): Tolerance parameter. max_iter: int Maximum number of iterations. - + Returns ------- alpha: array, shape = len(a) @@ -353,7 +353,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): def get_plan_from_dual(alpha, beta, C, regul): """ Retrieve optimal transportation plan from optimal dual potentials. - + Parameters ---------- alpha: array, shape = len(a) @@ -363,7 +363,7 @@ def get_plan_from_dual(alpha, beta, C, regul): Ground cost matrix. regul: Regularization object Should implement a delta_Omega(X) method. - + Returns ------- T: array, shape = len(a) x len(b) @@ -376,7 +376,7 @@ def get_plan_from_dual(alpha, beta, C, regul): def get_plan_from_semi_dual(alpha, b, C, regul): """ Retrieve optimal transportation plan from optimal semi-dual potentials. - + Parameters ---------- alpha: array, shape = len(a) @@ -387,7 +387,7 @@ def get_plan_from_semi_dual(alpha, b, C, regul): Ground cost matrix. regul: Regularization object Should implement a delta_Omega(X) method. - + Returns ------- T: array, shape = len(a) x len(b) @@ -399,20 +399,17 @@ def get_plan_from_semi_dual(alpha, b, C, regul): def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, numItermax=500, log=False): - - if reg_type.lower()=='l2': + + if reg_type.lower() == 'l2': regul = SquaredL2(gamma=reg) - elif reg_type.lower() in ['entropic','negentropy','kl']: - regul = NegEntropy(gamma=reg) - - alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,tol=stopThr) + elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: + regul = NegEntropy(gamma=reg) + + alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr) G = get_plan_from_dual(alpha, beta, M, regul) - + if log: - log={'alpha':alpha,beta:'beta','res':res} + log = {'alpha': alpha, beta: 'beta', 'res': res} return G, log else: return G - - - -- cgit v1.2.3