From a632c40b97c51534bfe65fee8b493780df118039 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 13 Jun 2017 14:39:59 +0200 Subject: make sinkhorn more general with method selection --- ot/bregman.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 107 insertions(+), 3 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 00dca88..6b3c68b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -5,8 +5,112 @@ Bregman projections for regularized OT import numpy as np +def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): + u""" + Solve the entropic regularization optimal transport problem + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples in the target domain + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.sinkhorn(a,b,M,1) + array([[ 0.36552929, 0.13447071], + [ 0.13447071, 0.36552929]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + + """ + + if method.lower()=='sinkhorn': + sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log,**kwargs) + elif method.lower()=='sinkhorn_stabilized': + sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower()=='sinkhorn_epsilon_scaling': + sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + else: + print('Warning : unknown method using classic Sinkhorn Knopp') + sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs) + + return sink() + + + + -def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False): +def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): """ Solve the entropic regularization optimal transport problem @@ -147,7 +251,7 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa else: return u.reshape((-1,1))*K*v.reshape((1,-1)) -def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False): +def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs): """ Solve the entropic regularization OT problem with log stabilization @@ -331,7 +435,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war else: return get_Gamma(alpha,beta,u,v) -def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False): +def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs): """ Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. -- cgit v1.2.3