From e585d64dd08e5367350e70f23e81f9fd2d676a6b Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 31 May 2018 11:33:08 +0200 Subject: pep8 --- ot/smooth.py | 67 +++++++++++++++++++++++++---------------------------- test/test_smooth.py | 26 ++++++++++++--------- 2 files changed, 47 insertions(+), 46 deletions(-) diff --git a/ot/smooth.py b/ot/smooth.py index f8bb20a..f4f4306 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -1,4 +1,4 @@ -#Copyright (c) 2018, Mathieu Blondel +#Copyright (c) 2018, Mathieu Blondel #All rights reserved. # #Redistribution and use in source and binary forms, with or without @@ -22,7 +22,7 @@ #OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF #THE POSSIBILITY OF SUCH DAMAGE. -# Author: Mathieu Blondel +# Author: Mathieu Blondel # Remi Flamary """ @@ -39,7 +39,7 @@ from scipy.optimize import minimize def projection_simplex(V, z=1, axis=None): """ Projection of x onto the simplex, scaled by z - + P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2 z: float or array If array, len(z) must be compatible with V @@ -65,19 +65,19 @@ def projection_simplex(V, z=1, axis=None): else: V = V.ravel().reshape(1, -1) return projection_simplex(V, z, axis=1).ravel() - + class Regularization(object): def __init__(self, gamma=1.0): """ - + Parameters ---------- gamma: float Regularization parameter. We recover unregularized OT when gamma -> 0. - + """ self.gamma = gamma @@ -85,12 +85,12 @@ class Regularization(object): """ Compute delta_Omega(X[:, j]) for each X[:, j]. delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y). - + Parameters ---------- X: array, shape = len(a) x len(b) Input array. - + Returns ------- v: array, len(b) @@ -104,12 +104,12 @@ class Regularization(object): """ Compute max_Omega_j(X[:, j]) for each X[:, j]. max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j]. - + Parameters ---------- X: array, shape = len(a) x len(b) Input array. - + Returns ------- v: array, len(b) @@ -122,12 +122,12 @@ class Regularization(object): def Omega(T): """ Compute regularization term. - + Parameters ---------- T: array, shape = len(a) x len(b) Input array. - + Returns ------- value: float @@ -176,7 +176,7 @@ class SquaredL2(Regularization): def dual_obj_grad(alpha, beta, a, b, C, regul): """ Compute objective value and gradients of dual objective. - + Parameters ---------- alpha: array, shape = len(a) @@ -189,7 +189,7 @@ def dual_obj_grad(alpha, beta, a, b, C, regul): Ground cost matrix. regul: Regularization object Should implement a delta_Omega(X) method. - + Returns ------- obj: float @@ -220,7 +220,7 @@ def dual_obj_grad(alpha, beta, a, b, C, regul): def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): """ Solve the "smoothed" dual objective. - + Parameters ---------- a: array, shape = len(a) @@ -236,7 +236,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): Tolerance parameter. max_iter: int Maximum number of iterations. - + Returns ------- alpha: array, shape = len(a) @@ -275,7 +275,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): def semi_dual_obj_grad(alpha, a, b, C, regul): """ Compute objective value and gradient of semi-dual objective. - + Parameters ---------- alpha: array, shape = len(a) @@ -287,7 +287,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): Ground cost matrix. regul: Regularization object Should implement a max_Omega(X) method. - + Returns ------- obj: float @@ -314,7 +314,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): """ Solve the "smoothed" semi-dual objective. - + Parameters ---------- a: array, shape = len(a) @@ -330,7 +330,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): Tolerance parameter. max_iter: int Maximum number of iterations. - + Returns ------- alpha: array, shape = len(a) @@ -353,7 +353,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): def get_plan_from_dual(alpha, beta, C, regul): """ Retrieve optimal transportation plan from optimal dual potentials. - + Parameters ---------- alpha: array, shape = len(a) @@ -363,7 +363,7 @@ def get_plan_from_dual(alpha, beta, C, regul): Ground cost matrix. regul: Regularization object Should implement a delta_Omega(X) method. - + Returns ------- T: array, shape = len(a) x len(b) @@ -376,7 +376,7 @@ def get_plan_from_dual(alpha, beta, C, regul): def get_plan_from_semi_dual(alpha, b, C, regul): """ Retrieve optimal transportation plan from optimal semi-dual potentials. - + Parameters ---------- alpha: array, shape = len(a) @@ -387,7 +387,7 @@ def get_plan_from_semi_dual(alpha, b, C, regul): Ground cost matrix. regul: Regularization object Should implement a delta_Omega(X) method. - + Returns ------- T: array, shape = len(a) x len(b) @@ -399,20 +399,17 @@ def get_plan_from_semi_dual(alpha, b, C, regul): def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, numItermax=500, log=False): - - if reg_type.lower()=='l2': + + if reg_type.lower() == 'l2': regul = SquaredL2(gamma=reg) - elif reg_type.lower() in ['entropic','negentropy','kl']: - regul = NegEntropy(gamma=reg) - - alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,tol=stopThr) + elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: + regul = NegEntropy(gamma=reg) + + alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr) G = get_plan_from_dual(alpha, beta, M, regul) - + if log: - log={'alpha':alpha,beta:'beta','res':res} + log = {'alpha': alpha, beta: 'beta', 'res': res} return G, log else: return G - - - diff --git a/test/test_smooth.py b/test/test_smooth.py index 4ca44f8..e95b3fe 100644 --- a/test/test_smooth.py +++ b/test/test_smooth.py @@ -4,17 +4,14 @@ # # License: MIT License -import warnings - import numpy as np import ot -from ot.datasets import get_1D_gauss as gauss -import pytest def test_smooth_ot_dual(): - # test sinkhorn + + # get data n = 100 rng = np.random.RandomState(0) @@ -23,15 +20,22 @@ def test_smooth_ot_dual(): M = ot.dist(x, x) - G = ot.smooth.smooth_ot_dual(u, u, M, 1, stopThr=1e-10) + Gl2 = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', stopThr=1e-10) + + # check constratints + np.testing.assert_allclose( + u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn + + # kl regyularisation + G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) # check constratints np.testing.assert_allclose( u, G.sum(1), atol=1e-05) # cf convergence sinkhorn np.testing.assert_allclose( - u, G.sum(0), atol=1e-05) # cf convergence sinkhorn - - + u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) - np.testing.assert_allclose( G, G2 , atol=1e-05) - \ No newline at end of file + np.testing.assert_allclose(G, G2, atol=1e-05) -- cgit v1.2.3