diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-07 08:51:45 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-07 08:51:45 +0200 |
commit | 47477c5782f87570de590ea423a082b71dd63241 (patch) | |
tree | 00575e528509242cdad2cd283d5e2e0c44e53349 | |
parent | 0fc1124e001932354ec3d229d198cba166cd0b0e (diff) |
add sinkhorn2 +v3
-rw-r--r-- | README.md | 5 | ||||
-rw-r--r-- | examples/plot_compute_emd.py | 4 | ||||
-rw-r--r-- | ot/__init__.py | 8 | ||||
-rw-r--r-- | ot/bregman.py | 144 |
4 files changed, 133 insertions, 28 deletions
@@ -83,14 +83,15 @@ import ot # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix Wd=ot.emd2(a,b,M) # exact linear program +Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT # if b is a matrix compute all distances to a and return a vector ``` * Compute OT matrix ```python # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix -Totp=ot.emd(a,b,M) # exact linear program -Totp_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT +T=ot.emd(a,b,M) # exact linear program +T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT ``` * Compute Wasserstein barycenter ```python diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index c7063e8..f2cdc35 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -61,8 +61,8 @@ pl.legend() #%% reg=1e-2 -d_sinkhorn=ot.sinkhorn(a,B,M,reg) -d_sinkhorn2=ot.sinkhorn(a,B,M2,reg) +d_sinkhorn=ot.sinkhorn2(a,B,M,reg) +d_sinkhorn2=ot.sinkhorn2(a,B,M2,reg) pl.figure(2) pl.clf() diff --git a/ot/__init__.py b/ot/__init__.py index b2af88b..4220148 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -16,14 +16,14 @@ from . 
import da # OT functions from .lp import emd, emd2 -from .bregman import sinkhorn, barycenter +from .bregman import sinkhorn, sinkhorn2, barycenter from .da import sinkhorn_lpl1_mm # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.2" +__version__ = "0.3" -__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp', - 'plot', 'tic', 'toc', 'toq', +__all__ = ["emd", "emd2", "sinkhorn","sinkhorn2", "utils", 'datasets', + 'bregman', 'lp', 'plot', 'tic', 'toc', 'toq', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim'] diff --git a/ot/bregman.py b/ot/bregman.py index 68be01c..0d68602 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -41,7 +41,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver Regularization term >0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_epsilon_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -91,7 +91,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] """ - + if method.lower()=='sinkhorn': sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log,**kwargs) @@ -100,15 +100,119 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower()=='sinkhorn_epsilon_scaling': sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: print('Warning : unknown method using classic Sinkhorn Knopp') sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs) - + return sink() + +def 
sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): + u""" + Solve the entropic regularization optimal transport problem and return the loss + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + W : (nt) ndarray or float + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.sinkhorn2(a,b,M,1) + array([ 0.26894142]) + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. 
[9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + + """ + + if method.lower()=='sinkhorn': + sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log,**kwargs) + elif method.lower()=='sinkhorn_stabilized': + sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower()=='sinkhorn_epsilon_scaling': + sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + else: + print('Warning : unknown method using classic Sinkhorn Knopp') + sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs) + + b=np.asarray(b,dtype=np.float64) + if len(b.shape)<2: + b=b.reshape((-1,1)) + + return sink() + def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): """ @@ -189,23 +293,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, a=np.asarray(a,dtype=np.float64) b=np.asarray(b,dtype=np.float64) M=np.asarray(M,dtype=np.float64) - + if len(a)==0: a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0] if len(b)==0: b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1] - + # init data Nini = len(a) Nfin = len(b) - + if len(b.shape)>1: nbb=b.shape[1] else: nbb=0 - + if log: log={'err':[]} @@ -217,7 +321,7 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, 
else: u = np.ones(Nini)/Nini v = np.ones(Nfin)/Nfin - + #print(reg) @@ -261,23 +365,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, if log: log['u']=u log['v']=v - - if nbb: #return only loss + + if nbb: #return only loss res=np.zeros((nbb)) for i in range(nbb): res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M) if log: return res,log else: - return res - + return res + else: # return OT matrix - + if log: return u.reshape((-1,1))*K*v.reshape((1,-1)),log else: return u.reshape((-1,1))*K*v.reshape((1,-1)) - + def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs): """ @@ -393,7 +497,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war alpha,beta=np.zeros(na),np.zeros(nb) else: alpha,beta=warmstart - + if nbb: u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb else: @@ -420,7 +524,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war uprev = u vprev = v - + # sinkhorn update v = b/(np.dot(K.T,u)+1e-16) u = a/(np.dot(K,v)+1e-16) @@ -471,8 +575,8 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war break cpt = cpt +1 - - + + #print('err=',err,' cpt=',cpt) if log: log['logu']=alpha/reg+np.log(u) @@ -493,7 +597,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war res=np.zeros((nbb)) for i in range(nbb): res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M) - return res + return res else: return get_Gamma(alpha,beta,u,v) |