diff options
-rw-r--r-- | README.md | 6 | ||||
-rw-r--r-- | examples/demo_OT_1D.py | 4 | ||||
-rw-r--r-- | examples/demo_optim_OTreg.py | 10 | ||||
-rw-r--r-- | ot/bregman.py | 349 | ||||
-rw-r--r-- | ot/datasets.py | 2 |
5 files changed, 362 insertions, 9 deletions
@@ -8,7 +8,7 @@ This open source Python library provide several solvers for optimization problem It provides the following solvers: * OT solver for the linear program/ Earth Movers Distance [1]. -* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2]. +* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10]. * Bregman projections for Wasserstein barycenter [3] and unmixing [4]. * Optimal transport for domain adaptation with group lasso regularization [5] * Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. @@ -98,3 +98,7 @@ This toolbox benefit a lot from open source research and we would like to thank [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. + +[9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + +[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
                        warmstart=None, verbose=False, print_period=20,
                        log=False):
    r"""
    Solve the entropic regularization OT problem with log stabilization

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in [2]_ but with the log stabilization
    proposed in [10]_ and defined in [9]_ (Algo 3.1).


    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain (an empty array means uniform
        weights)
    b : np.ndarray (nt,)
        samples weights in the target domain (an empty array means uniform
        weights)
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    tau : float, optional
        threshold on the max absolute value in u or v above which the
        scalings are absorbed into the log-domain potentials alpha and beta
    stopThr : float, optional
        Stop threshold on error (>0)
    warmstart : tuple of vectors, optional
        if given, starting values (alpha, beta) of the log scalings
    verbose : bool, optional
        Print information along iterations
    print_period : int, optional
        the error is evaluated (and optionally printed / logged) only every
        `print_period` iterations
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters. It holds
        the keys 'err', 'logu', 'logv', 'alpha', 'beta' and 'warmstart'.

    Examples
    --------

    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.bregman.sinkhorn_stabilized(a,b,M,1)
    array([[ 0.36552929,  0.13447071],
           [ 0.13447071,  0.36552929]])


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.


    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    # empty weight vectors are replaced by uniform weights
    if len(a) == 0:
        a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
    if len(b) == 0:
        b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]

    na = len(a)
    nb = len(b)

    # from here on `log` is either falsy or the dict that will be returned
    if log:
        log = {'err': []}

    if warmstart is None:
        alpha, beta = np.zeros(na), np.zeros(nb)
    else:
        # coerce so that lists/tuples are accepted (reshape is used below)
        alpha = np.asarray(warmstart[0], dtype=np.float64)
        beta = np.asarray(warmstart[1], dtype=np.float64)
    u, v = np.ones(na) / na, np.ones(nb) / nb

    def get_K(alpha, beta):
        """Kernel with the potentials alpha, beta absorbed (log space)."""
        return np.exp(-(M - alpha.reshape((na, 1))
                        - beta.reshape((1, nb))) / reg)

    def get_Gamma(alpha, beta, u, v):
        """Transport matrix computed in log space from potentials and scalings."""
        return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb)))
                      / reg + np.log(u.reshape((na, 1)))
                      + np.log(v.reshape((1, nb))))

    K = get_K(alpha, beta)
    cpt = 0
    err = 1
    loop = True
    while loop:

        # absorb large scalings into alpha/beta to keep u, v in a safe range
        if np.abs(u).max() > tau or np.abs(v).max() > tau:
            alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
            u, v = np.ones(na) / na, np.ones(nb) / nb
            K = get_K(alpha, beta)

        # keep the previous scalings so we can roll back on numerical failure
        uprev, vprev = u, v
        v = b / np.dot(K.T, u)
        u = a / np.dot(K, v)

        if cpt % print_period == 0:
            # the error is only evaluated every print_period iterations
            # because building the full transport matrix is costly
            transp = get_Gamma(alpha, beta, u, v)
            err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
            if log:
                log['err'].append(err)

            if verbose:
                if cpt % (print_period * 20) == 0:
                    print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))

        if err <= stopThr:
            loop = False

        if cpt >= numItermax:
            loop = False

        if (np.any(np.dot(K.T, u) == 0) or np.any(np.isnan(u))
                or np.any(np.isnan(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print('Warning: numerical errrors')
            if cpt != 0:
                u = uprev
                v = vprev
            break

        cpt = cpt + 1

    if log:
        log['logu'] = alpha / reg + np.log(u)
        log['logv'] = beta / reg + np.log(v)
        log['alpha'] = alpha + reg * np.log(u)
        log['beta'] = beta + reg * np.log(v)
        log['warmstart'] = (log['alpha'], log['beta'])
        return get_Gamma(alpha, beta, u, v), log
    else:
        return get_Gamma(alpha, beta, u, v)


def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
                             numInnerItermax=100, tau=1e3, stopThr=1e-9,
                             warmstart=None, verbose=False, print_period=10,
                             log=False):
    r"""
    Solve the entropic regularization optimal transport problem with log
    stabilization and epsilon scaling.

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in [2]_ but with the log stabilization
    proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2


    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain (an empty array means uniform
        weights)
    b : np.ndarray (nt,)
        samples weights in the target domain (an empty array means uniform
        weights)
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of outer (epsilon scaling) iterations; raised to at
        least 36 internally so that the last regularization equals `reg`
        up to 64-bit precision
    epsilon0 : float, optional
        first epsilon regularization value (then exponential decrease to reg)
    numInnerItermax : int, optional
        Max number of iterations in the inner log stabilized sinkhorn
    tau : float, optional
        threshold for max value in u or v for log scaling, forwarded to the
        inner stabilized solver
    stopThr : float, optional
        Stop threshold on error (>0) for the outer loop (the inner solver
        uses a fixed 1e-9 threshold)
    warmstart : tuple of vectors, optional
        if given, starting values (alpha, beta) of the log scalings
    verbose : bool, optional
        Print information along iterations
    print_period : int, optional
        the error is evaluated (and optionally printed / logged) only every
        `print_period` outer iterations
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters. It holds
        the keys 'err', 'alpha', 'beta' and 'warmstart'.

    Examples
    --------

    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.bregman.sinkhorn_epsilon_scaling(a,b,M,1)
    array([[ 0.36552929,  0.13447071],
           [ 0.13447071,  0.36552929]])


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    # empty weight vectors are replaced by uniform weights
    if len(a) == 0:
        a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
    if len(b) == 0:
        b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]

    na = len(a)
    nb = len(b)

    # exp(-35) is below the relative numerical precision of 64-bit floats,
    # so after more than 35 steps the regularization is numerically `reg`
    numItermin = 35
    numItermax = max(numItermin, numItermax)  # ensure that last value is exact

    # from here on `log` is either falsy or the dict that will be returned
    if log:
        log = {'err': []}

    if warmstart is None:
        alpha, beta = np.zeros(na), np.zeros(nb)
    else:
        alpha, beta = warmstart

    def get_reg(n):
        """Regularization at outer step n: exponential decrease to reg."""
        return (epsilon0 - reg) * np.exp(-n) + reg

    cpt = 0
    err = 1
    loop = True
    while loop:

        regi = get_reg(cpt)

        # solve at the current regularization, warm-started with the
        # potentials obtained at the previous (larger) regularization
        G, logi = sinkhorn_stabilized(a, b, M, regi,
                                      numItermax=numInnerItermax,
                                      stopThr=1e-9, warmstart=(alpha, beta),
                                      verbose=False, print_period=20,
                                      tau=tau, log=True)

        alpha = logi['alpha']
        beta = logi['beta']

        if cpt >= numItermax:
            loop = False

        if cpt % print_period == 0:
            # the error is only evaluated every print_period iterations
            err = (np.linalg.norm((np.sum(G, axis=0) - b))**2
                   + np.linalg.norm((np.sum(G, axis=1) - a))**2)
            if log:
                log['err'].append(err)

            if verbose:
                if cpt % (print_period * 10) == 0:
                    print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))

        # never stop before the regularization has actually reached `reg`
        if err <= stopThr and cpt > numItermin:
            loop = False

        cpt = cpt + 1

    if log:
        log['alpha'] = alpha
        log['beta'] = beta
        log['warmstart'] = (log['alpha'], log['beta'])
        return G, log
    else:
        return G