summaryrefslogtreecommitdiff
path: root/ot/smooth.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-31 12:02:13 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-31 12:02:13 +0200
commitb188381a8e8190958e4e9aca64b7501cf0321406 (patch)
tree3b43ee2e228e357569d478afaba0d99cca864638 /ot/smooth.py
parente585d64dd08e5367350e70f23e81f9fd2d676a6b (diff)
add semidual
Diffstat (limited to 'ot/smooth.py')
-rw-r--r--ot/smooth.py37
1 files changed, 34 insertions, 3 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index f4f4306..1a9972b 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -137,6 +137,7 @@ class Regularization(object):
class NegEntropy(Regularization):
+ """ NegEntropy regularization """
def delta_Omega(self, X):
G = np.exp(X / self.gamma - 1)
@@ -156,6 +157,7 @@ class NegEntropy(Regularization):
class SquaredL2(Regularization):
+ """ Squared L2 regularization """
def delta_Omega(self, X):
max_X = np.maximum(X, 0)
@@ -311,7 +313,8 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
return obj, grad
-def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
+def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ ):
"""
Solve the "smoothed" semi-dual objective.
@@ -400,16 +403,44 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
numItermax=500, log=False):
- if reg_type.lower() == 'l2':
+ if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+ # solve dual
alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
+
+ # reconstruct transport matrix
G = get_plan_from_dual(alpha, beta, M, regul)
if log:
- log = {'alpha': alpha, beta: 'beta', 'res': res}
+ log = {'alpha': alpha, 'beta': beta, 'res': res}
+ return G, log
+ else:
+ return G
+
+
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, log=False):
+
+ if reg_type.lower() in ['l2', 'squaredl2']:
+ regul = SquaredL2(gamma=reg)
+ elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
+ regul = NegEntropy(gamma=reg)
+ else:
+ raise NotImplementedError('Unknown regularization')
+
+ # solve dual
+ alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
+
+ # reconstruct transport matrix
+ G = get_plan_from_semi_dual(alpha, b, M, regul)
+
+ if log:
+ log = {'alpha': alpha, 'res': res}
return G, log
else:
return G