diff options
-rw-r--r-- | ot/smooth.py | 155 |
1 files changed, 147 insertions, 8 deletions
diff --git a/ot/smooth.py b/ot/smooth.py index 237b67a..4c46e73 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -226,7 +226,8 @@ def dual_obj_grad(alpha, beta, a, b, C, regul): return obj, grad_alpha, grad_beta -def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): +def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, + verbose=False): """ Solve the "smoothed" dual objective. @@ -273,7 +274,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500): params_init = np.concatenate((alpha_init, beta_init)) res = minimize(_func, params_init, method=method, jac=True, - tol=tol, options=dict(maxiter=max_iter, disp=False)) + tol=tol, options=dict(maxiter=max_iter, disp=verbose)) alpha = res.x[:len(a)] beta = res.x[len(a):] @@ -321,7 +322,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, - ): + verbose=False): """ Solve the "smoothed" semi-dual objective. @@ -355,7 +356,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, alpha_init = np.zeros(len(a)) res = minimize(_func, alpha_init, method=method, jac=True, - tol=tol, options=dict(maxiter=max_iter, disp=False)) + tol=tol, options=dict(maxiter=max_iter, disp=verbose)) return res.x, res @@ -408,7 +409,75 @@ def get_plan_from_semi_dual(alpha, b, C, regul): def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, - numItermax=500, log=False): + numItermax=500, verbose=False, log=False): + r""" + Solve the regularized OT problem in the dual and return the OT matrix + + The function solves the smooth relaxed dual formulation (7) in [17]_ : + + .. math:: + \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j) + + where : + + - :math:`\mathbf{m}_j` is the jth column of the cost matrix + - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega` + - a and b are source and target weights (sum to 1) + + The OT matrix can is reconstructed from the gradient of :math:`\delta_\Omega` + (See [17]_ Proposition 1). + The optimization algorithm is using gradient decent (L-BFGS by default). + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + reg_type : str + Regularization type, can be the following (default ='l2'): + - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_) + - 'l2' : Squared Euclidean regularization + method : str + Solver to use for scipy.optimize.minimize + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.sinhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + + """ if reg_type.lower() in ['l2', 'squaredl2']: regul = SquaredL2(gamma=reg) @@ -418,7 +487,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, raise NotImplementedError('Unknown regularization') # solve dual - alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr) + alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, + tol=stopThr, verbose=verbose) # reconstruct transport matrix G = get_plan_from_dual(alpha, beta, M, regul) @@ -431,8 +501,77 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, - numItermax=500, log=False): + numItermax=500, verbose=False, log=False): + r""" + Solve the regularized OT problem in the semi-dual and return the OT matrix + The function solves the smooth relaxed dual formulation (10) in [17]_ : + + .. math:: + \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b) + + where : + + .. math:: + OT_\Omega^*(\alpha,b)=\sum_j b_j + + - :math:`\mathbf{m}_j` is the jth column of the cost matrix + - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. (9) in [17] + - a and b are source and target weights (sum to 1) + + The OT matrix can is reconstructed using [17]_ Proposition 2. + The optimization algorithm is using gradient decent (L-BFGS by default). + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + reg_type : str + Regularization type, can be the following (default ='l2'): + - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_) + - 'l2' : Squared Euclidean regularization + method : str + Solver to use for scipy.optimize.minimize + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.sinhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + + """ if reg_type.lower() in ['l2', 'squaredl2']: regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: @@ -444,7 +583,7 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr) # reconstruct transport matrix - G = get_plan_from_semi_dual(alpha, b, M, regul) + G = get_plan_from_semi_dual(alpha, b, M, regul, verbose=verbose) if log: log = {'alpha': alpha, 'res': res} |