summaryrefslogtreecommitdiff
path: root/ot/smooth.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-31 13:36:12 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-31 13:36:12 +0200
commitfb883fcddff31cc4e21c921ebeb9a550eaa48ff0 (patch)
tree0b1fe9e43ae130135cdad6c82797da67acfa9744 /ot/smooth.py
parent724984d3e13aa9390ef17a389fbd6ccf1f277d69 (diff)
proper documentation
Diffstat (limited to 'ot/smooth.py')
-rw-r--r--ot/smooth.py155
1 files changed, 147 insertions, 8 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index 237b67a..4c46e73 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -226,7 +226,8 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
return obj, grad_alpha, grad_beta
-def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
+def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+ verbose=False):
"""
Solve the "smoothed" dual objective.
@@ -273,7 +274,7 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
params_init = np.concatenate((alpha_init, beta_init))
res = minimize(_func, params_init, method=method, jac=True,
- tol=tol, options=dict(maxiter=max_iter, disp=False))
+ tol=tol, options=dict(maxiter=max_iter, disp=verbose))
alpha = res.x[:len(a)]
beta = res.x[len(a):]
@@ -321,7 +322,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
- ):
+ verbose=False):
"""
Solve the "smoothed" semi-dual objective.
@@ -355,7 +356,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
alpha_init = np.zeros(len(a))
res = minimize(_func, alpha_init, method=method, jac=True,
- tol=tol, options=dict(maxiter=max_iter, disp=False))
+ tol=tol, options=dict(maxiter=max_iter, disp=verbose))
return res.x, res
@@ -408,7 +409,75 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
- numItermax=500, log=False):
+ numItermax=500, verbose=False, log=False):
+ r"""
+ Solve the regularized OT problem in the dual and return the OT matrix
+
+ The function solves the smooth relaxed dual formulation (7) in [17]_ :
+
+ .. math::
+ \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j)
+
+ where :
+
+ - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega`
+ - a and b are source and target weights (sum to 1)
+
+ The OT matrix can is reconstructed from the gradient of :math:`\delta_\Omega`
+ (See [17]_ Proposition 1).
+ The optimization algorithm is using gradient decent (L-BFGS by default).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ reg_type : str
+ Regularization type, can be the following (default ='l2'):
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+ - 'l2' : Squared Euclidean regularization
+ method : str
+ Solver to use for scipy.optimize.minimize
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.sinhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
@@ -418,7 +487,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
raise NotImplementedError('Unknown regularization')
# solve dual
- alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
+ alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
+ tol=stopThr, verbose=verbose)
# reconstruct transport matrix
G = get_plan_from_dual(alpha, beta, M, regul)
@@ -431,8 +501,77 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
- numItermax=500, log=False):
+ numItermax=500, verbose=False, log=False):
+ r"""
+ Solve the regularized OT problem in the semi-dual and return the OT matrix
+ The function solves the smooth relaxed dual formulation (10) in [17]_ :
+
+ .. math::
+ \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b)
+
+ where :
+
+ .. math::
+ OT_\Omega^*(\alpha,b)=\sum_j b_j
+
+ - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. (9) in [17]
+ - a and b are source and target weights (sum to 1)
+
+ The OT matrix can is reconstructed using [17]_ Proposition 2.
+ The optimization algorithm is using gradient decent (L-BFGS by default).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ reg_type : str
+ Regularization type, can be the following (default ='l2'):
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
+ - 'l2' : Squared Euclidean regularization
+ method : str
+ Solver to use for scipy.optimize.minimize
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.sinhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
@@ -444,7 +583,7 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
# reconstruct transport matrix
- G = get_plan_from_semi_dual(alpha, b, M, regul)
+ G = get_plan_from_semi_dual(alpha, b, M, regul, verbose=verbose)
if log:
log = {'alpha': alpha, 'res': res}