summaryrefslogtreecommitdiff
path: root/ot/smooth.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-31 11:33:08 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-31 11:33:08 +0200
commite585d64dd08e5367350e70f23e81f9fd2d676a6b (patch)
tree7eb65a7dd80921ffedbe215b6fa911cbff4e796b /ot/smooth.py
parentaf99d3ff66062811a81454ea03e6b831a1292ae4 (diff)
pep8
Diffstat (limited to 'ot/smooth.py')
-rw-r--r--ot/smooth.py67
1 files changed, 32 insertions, 35 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index f8bb20a..f4f4306 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -1,4 +1,4 @@
-#Copyright (c) 2018, Mathieu Blondel
+#Copyright (c) 2018, Mathieu Blondel
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
@@ -22,7 +22,7 @@
#OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
#THE POSSIBILITY OF SUCH DAMAGE.
-# Author: Mathieu Blondel
+# Author: Mathieu Blondel
# Remi Flamary <remi.flamary@unice.fr>
"""
@@ -39,7 +39,7 @@ from scipy.optimize import minimize
def projection_simplex(V, z=1, axis=None):
""" Projection of x onto the simplex, scaled by z
-
+
P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
z: float or array
If array, len(z) must be compatible with V
@@ -65,19 +65,19 @@ def projection_simplex(V, z=1, axis=None):
else:
V = V.ravel().reshape(1, -1)
return projection_simplex(V, z, axis=1).ravel()
-
+
class Regularization(object):
def __init__(self, gamma=1.0):
"""
-
+
Parameters
----------
gamma: float
Regularization parameter.
We recover unregularized OT when gamma -> 0.
-
+
"""
self.gamma = gamma
@@ -85,12 +85,12 @@ class Regularization(object):
"""
Compute delta_Omega(X[:, j]) for each X[:, j].
delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
-
+
Parameters
----------
X: array, shape = len(a) x len(b)
Input array.
-
+
Returns
-------
v: array, len(b)
@@ -104,12 +104,12 @@ class Regularization(object):
"""
Compute max_Omega_j(X[:, j]) for each X[:, j].
max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
-
+
Parameters
----------
X: array, shape = len(a) x len(b)
Input array.
-
+
Returns
-------
v: array, len(b)
@@ -122,12 +122,12 @@ class Regularization(object):
def Omega(T):
"""
Compute regularization term.
-
+
Parameters
----------
T: array, shape = len(a) x len(b)
Input array.
-
+
Returns
-------
value: float
@@ -176,7 +176,7 @@ class SquaredL2(Regularization):
def dual_obj_grad(alpha, beta, a, b, C, regul):
"""
Compute objective value and gradients of dual objective.
-
+
Parameters
----------
alpha: array, shape = len(a)
@@ -189,7 +189,7 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a delta_Omega(X) method.
-
+
Returns
-------
obj: float
@@ -220,7 +220,7 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
"""
Solve the "smoothed" dual objective.
-
+
Parameters
----------
a: array, shape = len(a)
@@ -236,7 +236,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
Tolerance parameter.
max_iter: int
Maximum number of iterations.
-
+
Returns
-------
alpha: array, shape = len(a)
@@ -275,7 +275,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
def semi_dual_obj_grad(alpha, a, b, C, regul):
"""
Compute objective value and gradient of semi-dual objective.
-
+
Parameters
----------
alpha: array, shape = len(a)
@@ -287,7 +287,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a max_Omega(X) method.
-
+
Returns
-------
obj: float
@@ -314,7 +314,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
"""
Solve the "smoothed" semi-dual objective.
-
+
Parameters
----------
a: array, shape = len(a)
@@ -330,7 +330,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
Tolerance parameter.
max_iter: int
Maximum number of iterations.
-
+
Returns
-------
alpha: array, shape = len(a)
@@ -353,7 +353,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
def get_plan_from_dual(alpha, beta, C, regul):
"""
Retrieve optimal transportation plan from optimal dual potentials.
-
+
Parameters
----------
alpha: array, shape = len(a)
@@ -363,7 +363,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a delta_Omega(X) method.
-
+
Returns
-------
T: array, shape = len(a) x len(b)
@@ -376,7 +376,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
def get_plan_from_semi_dual(alpha, b, C, regul):
"""
Retrieve optimal transportation plan from optimal semi-dual potentials.
-
+
Parameters
----------
alpha: array, shape = len(a)
@@ -387,7 +387,7 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a delta_Omega(X) method.
-
+
Returns
-------
T: array, shape = len(a) x len(b)
@@ -399,20 +399,17 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
numItermax=500, log=False):
-
- if reg_type.lower()=='l2':
+
+ if reg_type.lower() == 'l2':
regul = SquaredL2(gamma=reg)
- elif reg_type.lower() in ['entropic','negentropy','kl']:
- regul = NegEntropy(gamma=reg)
-
- alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,tol=stopThr)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+
+ alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
G = get_plan_from_dual(alpha, beta, M, regul)
-
+
if log:
- log={'alpha':alpha,beta:'beta','res':res}
+ log = {'alpha': alpha, beta: 'beta', 'res': res}
return G, log
else:
return G
-
-
-