summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md6
-rw-r--r--examples/demo_OT_1D.py4
-rw-r--r--examples/demo_optim_OTreg.py10
-rw-r--r--ot/bregman.py349
-rw-r--r--ot/datasets.py2
5 files changed, 362 insertions, 9 deletions
diff --git a/README.md b/README.md
index 4088bc2..c6d480b 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:
* OT solver for the linear program/ Earth Movers Distance [1].
-* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2].
+* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10].
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -98,3 +98,7 @@ This toolbox benefit a lot from open source research and we would like to thank
[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
+
+[9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py
index 6eaa2ff..df65a60 100644
--- a/examples/demo_OT_1D.py
+++ b/examples/demo_OT_1D.py
@@ -19,8 +19,8 @@ n=100 # nb bins
x=np.arange(n,dtype=np.float64)
# Gaussian distributions
-a=gauss(n,m=20,s=20) # m= mean, s= std
-b=gauss(n,m=60,s=60)
+a=gauss(n,m=20,s=5) # m= mean, s= std
+b=gauss(n,m=60,s=10)
# loss matrix
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
diff --git a/examples/demo_optim_OTreg.py b/examples/demo_optim_OTreg.py
index 5e19be5..0a8c583 100644
--- a/examples/demo_optim_OTreg.py
+++ b/examples/demo_optim_OTreg.py
@@ -17,8 +17,8 @@ n=100 # nb bins
x=np.arange(n,dtype=np.float64)
# Gaussian distributions
-a=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std
-b=ot.datasets.get_1D_gauss(n,m=60,s=60)
+a=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
+b=ot.datasets.get_1D_gauss(n,m=60,s=10)
# loss matrix
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
@@ -37,7 +37,7 @@ def f(G): return 0.5*np.sum(G**2)
def df(G): return G
reg=1e-1
-
+
Gl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True)
pl.figure(3)
@@ -47,9 +47,9 @@ ot.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg')
def f(G): return np.sum(G*np.log(G))
def df(G): return np.log(G)+1
-
+
reg=1e-3
-
+
Ge=ot.optim.cg(a,b,M,reg,f,df,verbose=True)
pl.figure(4)
diff --git a/ot/bregman.py b/ot/bregman.py
index a770c5a..b132225 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -144,6 +144,355 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
else:
return np.dot(np.diag(u),np.dot(K,np.diag(v)))
+def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False):
+ """
+ Solve the entropic regularization OT problem with log stabilization
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_ but with the log stabilization
+ proposed in [10]_ an defined in [9]_ (Algo 3.1) .
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ tau : float
+ thershold for max value in u or v for log scaling
+ warmstart : tible of vectors
+ if given then sarting values for alpha an beta log scalings
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.sinkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a=np.asarray(a,dtype=np.float64)
+ b=np.asarray(b,dtype=np.float64)
+ M=np.asarray(M,dtype=np.float64)
+
+ if len(a)==0:
+ a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
+ if len(b)==0:
+ b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
+ # init data
+ na = len(a)
+ nb = len(b)
+
+
+ cpt = 0
+ if log:
+ log={'err':[]}
+
+ # we assume that no distances are null except those of the diagonal of distances
+ if warmstart is None:
+ alpha,beta=np.zeros(na),np.zeros(nb)
+ else:
+ alpha,beta=warmstart
+ u,v = np.ones(na)/na,np.ones(nb)/nb
+ uprev,vprev=np.zeros(na),np.zeros(nb)
+
+
+ #print reg
+
+
+ def get_K(alpha,beta):
+ """log space computation"""
+ return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg)
+
+ def get_Gamma(alpha,beta,u,v):
+ """log space gamma computation"""
+ return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg+np.log(u.reshape((na,1)))+np.log(v.reshape((1,nb))))
+
+ #print np.min(K)
+
+ K=get_K(alpha,beta)
+ transp = K
+ loop=1
+ cpt = 0
+ err=1
+ while loop:
+
+ if np.abs(u).max()>tau or np.abs(v).max()>tau:
+ alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v)
+ u,v = np.ones(na)/na,np.ones(nb)/nb
+ K=get_K(alpha,beta)
+
+ uprev = u
+ vprev = v
+ v = b/np.dot(K.T,u)
+ u = a/np.dot(K,v)
+
+
+
+ if cpt%print_period==0:
+ # we can speed up the process by checking for the error only all the 10th iterations
+ transp = get_Gamma(alpha,beta,u,v)
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt%(print_period*20) ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
+
+
+ if err<=stopThr:
+ loop=False
+
+ if cpt>=numItermax:
+ loop=False
+
+
+ if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errrors')
+ if cpt!=0:
+ u = uprev
+ v = vprev
+ break
+
+ cpt = cpt +1
+ #print 'err=',err,' cpt=',cpt
+ if log:
+ log['logu']=alpha/reg+np.log(u)
+ log['logv']=beta/reg+np.log(v)
+ log['alpha']=alpha+reg*np.log(u)
+ log['beta']=beta+reg*np.log(v)
+ log['warmstart']=(log['alpha'],log['beta'])
+ return get_Gamma(alpha,beta,u,v),log
+ else:
+ return get_Gamma(alpha,beta,u,v)
+
+def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False):
+ """
+ Solve the entropic regularization optimal transport problem with log
+ stabilization and epsilon scaling.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_ but with the log stabilization
+ proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ tau : float
+ thershold for max value in u or v for log scaling
+ tau : float
+ thershold for max value in u or v for log scaling
+ warmstart : tible of vectors
+ if given then sarting values for alpha an beta log scalings
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterationsin the inner slog stabilized sinkhorn
+ epsilon0 : int, optional
+ first epsilon regularization value (then exponential decrease to reg)
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.sinkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a=np.asarray(a,dtype=np.float64)
+ b=np.asarray(b,dtype=np.float64)
+ M=np.asarray(M,dtype=np.float64)
+
+ if len(a)==0:
+ a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
+ if len(b)==0:
+ b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
+ # init data
+ na = len(a)
+ nb = len(b)
+
+ # nrelative umerical precision with 64 bits
+ numItermin = 35
+ numItermax=max(numItermin,numItermax) # ensure that last velue is exact
+
+
+ cpt = 0
+ if log:
+ log={'err':[]}
+
+ # we assume that no distances are null except those of the diagonal of distances
+ if warmstart is None:
+ alpha,beta=np.zeros(na),np.zeros(nb)
+ else:
+ alpha,beta=warmstart
+
+
+ def get_K(alpha,beta):
+ """log space computation"""
+ return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg)
+
+ #print np.min(K)
+ def get_reg(n): # exponential decreasing
+ return (epsilon0-reg)*np.exp(-n)+reg
+
+ loop=1
+ cpt = 0
+ err=1
+ while loop:
+
+ regi=get_reg(cpt)
+
+ G,logi=sinkhorn_stabilized(a,b, M, regi, numItermax = numInnerItermax,tau=1e3, stopThr=1e-9,warmstart=(alpha,beta), verbose=False,print_period=20,tau=tau, log=True)
+
+ alpha=logi['alpha']
+ beta=logi['beta']
+
+ if cpt>=numItermax:
+ loop=False
+
+ if cpt%(print_period)==0: # spsion nearly converged
+ # we can speed up the process by checking for the error only all the 10th iterations
+ transp = G
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2+np.linalg.norm((np.sum(transp,axis=1)-a))**2
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt%(print_period*10) ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
+
+ if err<=stopThr and cpt>numItermin:
+ loop=False
+
+ cpt = cpt +1
+ #print 'err=',err,' cpt=',cpt
+ if log:
+ log['alpha']=alpha
+ log['beta']=beta
+ log['warmstart']=(log['alpha'],log['beta'])
+ return G,log
+ else:
+ return G
+
def geometricBar(weights,alldistribT):
"""return the weighted geometric mean of distributions"""
diff --git a/ot/datasets.py b/ot/datasets.py
index c750812..8605691 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -28,7 +28,7 @@ def get_1D_gauss(n,m,s):
"""
x=np.arange(n,dtype=np.float64)
- h=np.exp(-(x-m)**2/(2*s^2))
+ h=np.exp(-(x-m)**2/(2*s**2))
return h/h.sum()