summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/smooth.py37
-rw-r--r--test/test_smooth.py34
2 files changed, 67 insertions, 4 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index f4f4306..1a9972b 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -137,6 +137,7 @@ class Regularization(object):
class NegEntropy(Regularization):
+ """ NegEntropy regularization """
def delta_Omega(self, X):
G = np.exp(X / self.gamma - 1)
@@ -156,6 +157,7 @@ class NegEntropy(Regularization):
class SquaredL2(Regularization):
+ """ Squared L2 regularization """
def delta_Omega(self, X):
max_X = np.maximum(X, 0)
@@ -311,7 +313,8 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
return obj, grad
-def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
+def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ ):
"""
Solve the "smoothed" semi-dual objective.
@@ -400,16 +403,44 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
numItermax=500, log=False):
- if reg_type.lower() == 'l2':
+ if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+ # solve dual
alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
+
+ # reconstruct transport matrix
G = get_plan_from_dual(alpha, beta, M, regul)
if log:
- log = {'alpha': alpha, beta: 'beta', 'res': res}
+ log = {'alpha': alpha, 'beta': beta, 'res': res}
+ return G, log
+ else:
+ return G
+
+
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, log=False):
+
+ if reg_type.lower() in ['l2', 'squaredl2']:
+ regul = SquaredL2(gamma=reg)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+
+ # solve dual
+ alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
+
+ # reconstruct transport matrix
+ G = get_plan_from_semi_dual(alpha, b, M, regul)
+
+ if log:
+ log = {'alpha': alpha, 'res': res}
return G, log
else:
return G
diff --git a/test/test_smooth.py b/test/test_smooth.py
index e95b3fe..37cc66e 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -20,7 +20,7 @@ def test_smooth_ot_dual():
M = ot.dist(x, x)
- Gl2 = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', stopThr=1e-10)
+ Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
# check constratints
np.testing.assert_allclose(
@@ -39,3 +39,35 @@ def test_smooth_ot_dual():
G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
np.testing.assert_allclose(G, G2, atol=1e-05)
+
+
+def test_smooth_ot_semi_dual():
+
+ # get data
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
+
+ # check constratints
+ np.testing.assert_allclose(
+ u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ # kl regyularisation
+ G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
+
+ # check constratints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ np.testing.assert_allclose(G, G2, atol=1e-05)