author	Rémi Flamary <remi.flamary@gmail.com>	2016-11-07 16:51:42 +0100
committer	Rémi Flamary <remi.flamary@gmail.com>	2016-11-07 16:51:42 +0100
commit	8b41e141e8ce4ad14a458cb363f46d3176644116 (patch)
tree	84110632f983344feaa757b228c67a1b9c6c3aa1 /ot
parent	e485078116660e53b47aa1f7b96288b9b413c3eb (diff)
add log and epsilon scaling stabilizations
Diffstat (limited to 'ot')
-rw-r--r--	ot/bregman.py	349
-rw-r--r--	ot/datasets.py	2
2 files changed, 350 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index a770c5a..b132225 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -144,6 +144,355 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
else:
return np.dot(np.diag(u),np.dot(K,np.diag(v)))
+def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False):
+ """
+ Solve the entropic regularization OT problem with log stabilization
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \arg\min_\gamma \langle \gamma, M \rangle_F + \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1 = b
+
+ \gamma \geq 0
+
+ where:
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_, but with the log stabilization
+ proposed in [10]_ and defined in [9]_ (Algo 3.1): whenever the scaling
+ vectors u or v exceed the threshold tau, their logs are absorbed into the
+ dual potentials alpha and beta, which keeps all stored values bounded.
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ sample weights in the source domain
+ b : np.ndarray (nt,)
+ sample weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ tau : float
+ threshold for max value in u or v before log scaling absorption
+ warmstart : tuple of vectors
+ if given then starting values for alpha and beta log scalings
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.bregman.sinkhorn_stabilized(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
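+
+ A minimal warm-start sketch (variable names illustrative): the duals stored
+ under log['warmstart'] can seed a second call through the warmstart parameter.
+
+ >>> G, log_s = ot.bregman.sinkhorn_stabilized(a, b, M, 1, log=True)
+ >>> G2 = ot.bregman.sinkhorn_stabilized(a, b, M, 1, warmstart=log_s['warmstart'])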
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a=np.asarray(a,dtype=np.float64)
+ b=np.asarray(b,dtype=np.float64)
+ M=np.asarray(M,dtype=np.float64)
+
+ if len(a)==0:
+ a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
+ if len(b)==0:
+ b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
+ # init data
+ na = len(a)
+ nb = len(b)
+
+
+ cpt = 0
+ if log:
+ log={'err':[]}
+
+ # we assume that no distances are null except those on the diagonal
+ if warmstart is None:
+ alpha,beta=np.zeros(na),np.zeros(nb)
+ else:
+ alpha,beta=warmstart
+ u,v = np.ones(na)/na,np.ones(nb)/nb
+ uprev,vprev=np.zeros(na),np.zeros(nb)
+
+ def get_K(alpha,beta):
+ """kernel K = exp(-(M - alpha - beta^T)/reg), computed in log space"""
+ return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg)
+
+ def get_Gamma(alpha,beta,u,v):
+ """transport plan from duals (alpha,beta) and scalings (u,v), computed in log space"""
+ return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg+np.log(u.reshape((na,1)))+np.log(v.reshape((1,nb))))
+
+ K=get_K(alpha,beta)
+ transp = K
+ loop=1
+ cpt = 0
+ err=1
+ while loop:
+
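+ # log stabilization ([9] Algo 3.1): when u or v exceed tau, absorb their
+ # logs into the duals alpha, beta and reset u, v to uniform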
+ if np.abs(u).max()>tau or np.abs(v).max()>tau:
+ alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v)
+ u,v = np.ones(na)/na,np.ones(nb)/nb
+ K=get_K(alpha,beta)
+
+ uprev = u
+ vprev = v
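+ # classical Sinkhorn updates, performed on the stabilized kernel K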
+ v = b/np.dot(K.T,u)
+ u = a/np.dot(K,v)
+
+
+
+ if cpt%print_period==0:
+ # we can speed up the process by checking the error only every print_period iterations
+ transp = get_Gamma(alpha,beta,u,v)
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt%(print_period*20) ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
+
+
+ if err<=stopThr:
+ loop=False
+
+ if cpt>=numItermax:
+ loop=False
+
+
+ if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errors')
+ if cpt!=0:
+ u = uprev
+ v = vprev
+ break
+
+ cpt = cpt +1
+ if log:
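+ # store duals with the current scalings absorbed; the 'warmstart' entry
+ # can be passed to a later call to resume from these potentials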
+ log['logu']=alpha/reg+np.log(u)
+ log['logv']=beta/reg+np.log(v)
+ log['alpha']=alpha+reg*np.log(u)
+ log['beta']=beta+reg*np.log(v)
+ log['warmstart']=(log['alpha'],log['beta'])
+ return get_Gamma(alpha,beta,u,v),log
+ else:
+ return get_Gamma(alpha,beta,u,v)
+
+def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False):
+ """
+ Solve the entropic regularization optimal transport problem with log
+ stabilization and epsilon scaling.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \arg\min_\gamma \langle \gamma, M \rangle_F + \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1 = b
+
+ \gamma \geq 0
+
+ where:
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_, but with the log stabilization
+ proposed in [10]_ and the epsilon scaling proposed in [9]_ (Algo 3.2).
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ sample weights in the source domain
+ b : np.ndarray (nt,)
+ sample weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ tau : float
+ threshold for max value in u or v before log scaling absorption
+ warmstart : tuple of vectors
+ if given then starting values for alpha and beta log scalings
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations in the inner log-stabilized Sinkhorn
+ epsilon0 : float, optional
+ first epsilon regularization value (then exponential decrease to reg)
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.bregman.sinkhorn_epsilon_scaling(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
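+
+ A minimal warm-start sketch (variable names illustrative), reusing the duals
+ returned in the log dict:
+
+ >>> G, log_es = ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1, log=True)
+ >>> G2 = ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1, warmstart=log_es['warmstart'])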
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a=np.asarray(a,dtype=np.float64)
+ b=np.asarray(b,dtype=np.float64)
+ M=np.asarray(M,dtype=np.float64)
+
+ if len(a)==0:
+ a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
+ if len(b)==0:
+ b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
+ # init data
+ na = len(a)
+ nb = len(b)
+
+ # relative numerical precision with 64 bits
+ numItermin = 35
+ numItermax=max(numItermin,numItermax) # ensure that the last value is exact
+
+
+ cpt = 0
+ if log:
+ log={'err':[]}
+
+ # we assume that no distances are null except those on the diagonal
+ if warmstart is None:
+ alpha,beta=np.zeros(na),np.zeros(nb)
+ else:
+ alpha,beta=warmstart
+
+
+ def get_K(alpha,beta):
+ """kernel K = exp(-(M - alpha - beta^T)/reg), computed in log space"""
+ return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg)
+
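+ # epsilon scaling heuristic: start from a large regularization epsilon0 and
+ # decrease it exponentially toward the target reg at each outer iteration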
+ def get_reg(n): # exponentially decreasing regularization
+ return (epsilon0-reg)*np.exp(-n)+reg
+
+ loop=1
+ cpt = 0
+ err=1
+ while loop:
+
+ regi=get_reg(cpt)
+
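+ # inner loop: log-stabilized Sinkhorn at the current regularization,
+ # warm started from the duals of the previous outer iteration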
+ G,logi=sinkhorn_stabilized(a,b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(alpha,beta), verbose=False, print_period=20, tau=tau, log=True)
+
+ alpha=logi['alpha']
+ beta=logi['beta']
+
+ if cpt>=numItermax:
+ loop=False
+
+ if cpt%(print_period)==0: # epsilon nearly converged at this point
+ # we can speed up the process by checking the error only every print_period iterations
+ transp = G
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2+np.linalg.norm((np.sum(transp,axis=1)-a))**2
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt%(print_period*10) ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
+
+ if err<=stopThr and cpt>numItermin:
+ loop=False
+
+ cpt = cpt +1
+ if log:
+ log['alpha']=alpha
+ log['beta']=beta
+ log['warmstart']=(log['alpha'],log['beta'])
+ return G,log
+ else:
+ return G
+
def geometricBar(weights,alldistribT):
"""return the weighted geometric mean of distributions"""
diff --git a/ot/datasets.py b/ot/datasets.py
index c750812..8605691 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -28,7 +28,7 @@ def get_1D_gauss(n,m,s):
"""
x=np.arange(n,dtype=np.float64)
- h=np.exp(-(x-m)**2/(2*s^2))
+ h=np.exp(-(x-m)**2/(2*s**2))
return h/h.sum()