From b188381a8e8190958e4e9aca64b7501cf0321406 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 31 May 2018 12:02:13 +0200 Subject: add semidual --- ot/smooth.py | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) (limited to 'ot/smooth.py') diff --git a/ot/smooth.py b/ot/smooth.py index f4f4306..1a9972b 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -137,6 +137,7 @@ class Regularization(object): class NegEntropy(Regularization): + """ NegEntropy regularization """ def delta_Omega(self, X): G = np.exp(X / self.gamma - 1) @@ -156,6 +157,7 @@ class NegEntropy(Regularization): class SquaredL2(Regularization): + """ Squared L2 regularization """ def delta_Omega(self, X): max_X = np.maximum(X, 0) @@ -311,7 +313,8 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): return obj, grad -def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): +def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, + ): """ Solve the "smoothed" semi-dual objective. @@ -400,16 +403,44 @@ def get_plan_from_semi_dual(alpha, b, C, regul): def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, numItermax=500, log=False): - if reg_type.lower() == 'l2': + if reg_type.lower() in ['l2', 'squaredl2']: regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: regul = NegEntropy(gamma=reg) + else: + raise NotImplementedError('Unknown regularization') + # solve dual alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr) + + # reconstruct transport matrix G = get_plan_from_dual(alpha, beta, M, regul) if log: - log = {'alpha': alpha, beta: 'beta', 'res': res} + log = {'alpha': alpha, 'beta': beta, 'res': res} + return G, log + else: + return G + + +def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, + numItermax=500, log=False): + + if reg_type.lower() in ['l2', 'squaredl2']: + regul = SquaredL2(gamma=reg) + elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: + regul = NegEntropy(gamma=reg) + else: + raise NotImplementedError('Unknown regularization') + + # solve dual + alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr) + + # reconstruct transport matrix + G = get_plan_from_semi_dual(alpha, b, M, regul) + + if log: + log = {'alpha': alpha, 'res': res} return G, log else: return G -- cgit v1.2.3