summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/smooth.py67
-rw-r--r--test/test_smooth.py26
2 files changed, 47 insertions, 46 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index f8bb20a..f4f4306 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -1,4 +1,4 @@
-#Copyright (c) 2018, Mathieu Blondel
+#Copyright (c) 2018, Mathieu Blondel
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
@@ -22,7 +22,7 @@
#OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
#THE POSSIBILITY OF SUCH DAMAGE.
-# Author: Mathieu Blondel
+# Author: Mathieu Blondel
# Remi Flamary <remi.flamary@unice.fr>
"""
@@ -39,7 +39,7 @@ from scipy.optimize import minimize
def projection_simplex(V, z=1, axis=None):
""" Projection of x onto the simplex, scaled by z
-
+
P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
z: float or array
If array, len(z) must be compatible with V
@@ -65,19 +65,19 @@ def projection_simplex(V, z=1, axis=None):
else:
V = V.ravel().reshape(1, -1)
return projection_simplex(V, z, axis=1).ravel()
-
+
class Regularization(object):
def __init__(self, gamma=1.0):
"""
-
+
Parameters
----------
gamma: float
Regularization parameter.
We recover unregularized OT when gamma -> 0.
-
+
"""
self.gamma = gamma
@@ -85,12 +85,12 @@ class Regularization(object):
"""
Compute delta_Omega(X[:, j]) for each X[:, j].
delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
-
+
Parameters
----------
X: array, shape = len(a) x len(b)
Input array.
-
+
Returns
-------
v: array, len(b)
@@ -104,12 +104,12 @@ class Regularization(object):
"""
Compute max_Omega_j(X[:, j]) for each X[:, j].
max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
-
+
Parameters
----------
X: array, shape = len(a) x len(b)
Input array.
-
+
Returns
-------
v: array, len(b)
@@ -122,12 +122,12 @@ class Regularization(object):
def Omega(T):
"""
Compute regularization term.
-
+
Parameters
----------
T: array, shape = len(a) x len(b)
Input array.
-
+
Returns
-------
value: float
@@ -176,7 +176,7 @@ class SquaredL2(Regularization):
def dual_obj_grad(alpha, beta, a, b, C, regul):
"""
Compute objective value and gradients of dual objective.
-
+
Parameters
----------
alpha: array, shape = len(a)
@@ -189,7 +189,7 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a delta_Omega(X) method.
-
+
Returns
-------
obj: float
@@ -220,7 +220,7 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
"""
Solve the "smoothed" dual objective.
-
+
Parameters
----------
a: array, shape = len(a)
@@ -236,7 +236,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
Tolerance parameter.
max_iter: int
Maximum number of iterations.
-
+
Returns
-------
alpha: array, shape = len(a)
@@ -275,7 +275,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
def semi_dual_obj_grad(alpha, a, b, C, regul):
"""
Compute objective value and gradient of semi-dual objective.
-
+
Parameters
----------
alpha: array, shape = len(a)
@@ -287,7 +287,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a max_Omega(X) method.
-
+
Returns
-------
obj: float
@@ -314,7 +314,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
"""
Solve the "smoothed" semi-dual objective.
-
+
Parameters
----------
a: array, shape = len(a)
@@ -330,7 +330,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
Tolerance parameter.
max_iter: int
Maximum number of iterations.
-
+
Returns
-------
alpha: array, shape = len(a)
@@ -353,7 +353,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
def get_plan_from_dual(alpha, beta, C, regul):
"""
Retrieve optimal transportation plan from optimal dual potentials.
-
+
Parameters
----------
alpha: array, shape = len(a)
@@ -363,7 +363,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a delta_Omega(X) method.
-
+
Returns
-------
T: array, shape = len(a) x len(b)
@@ -376,7 +376,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
def get_plan_from_semi_dual(alpha, b, C, regul):
"""
Retrieve optimal transportation plan from optimal semi-dual potentials.
-
+
Parameters
----------
alpha: array, shape = len(a)
@@ -387,7 +387,7 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
Ground cost matrix.
regul: Regularization object
Should implement a delta_Omega(X) method.
-
+
Returns
-------
T: array, shape = len(a) x len(b)
@@ -399,20 +399,17 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
numItermax=500, log=False):
-
- if reg_type.lower()=='l2':
+
+ if reg_type.lower() == 'l2':
regul = SquaredL2(gamma=reg)
- elif reg_type.lower() in ['entropic','negentropy','kl']:
- regul = NegEntropy(gamma=reg)
-
- alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,tol=stopThr)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+
+ alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
G = get_plan_from_dual(alpha, beta, M, regul)
-
+
if log:
- log={'alpha':alpha,beta:'beta','res':res}
+ log = {'alpha': alpha, beta: 'beta', 'res': res}
return G, log
else:
return G
-
-
-
diff --git a/test/test_smooth.py b/test/test_smooth.py
index 4ca44f8..e95b3fe 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -4,17 +4,14 @@
#
# License: MIT License
-import warnings
-
import numpy as np
import ot
-from ot.datasets import get_1D_gauss as gauss
-import pytest
def test_smooth_ot_dual():
- # test sinkhorn
+
+ # get data
n = 100
rng = np.random.RandomState(0)
@@ -23,15 +20,22 @@ def test_smooth_ot_dual():
M = ot.dist(x, x)
- G = ot.smooth.smooth_ot_dual(u, u, M, 1, stopThr=1e-10)
+ Gl2 = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', stopThr=1e-10)
+
+ # check constratints
+ np.testing.assert_allclose(
+ u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ # kl regyularisation
+ G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
# check constratints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
-
-
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
- np.testing.assert_allclose( G, G2 , atol=1e-05)
- \ No newline at end of file
+ np.testing.assert_allclose(G, G2, atol=1e-05)