summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md5
-rw-r--r--examples/plot_compute_emd.py4
-rw-r--r--ot/__init__.py8
-rw-r--r--ot/bregman.py144
4 files changed, 133 insertions, 28 deletions
diff --git a/README.md b/README.md
index 4aa4cc5..d53387b 100644
--- a/README.md
+++ b/README.md
@@ -83,14 +83,15 @@ import ot
# a,b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
Wd=ot.emd2(a,b,M) # exact linear program
+Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
# if b is a matrix compute all distances to a and return a vector
```
* Compute OT matrix
```python
# a,b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
-Totp=ot.emd(a,b,M) # exact linear program
-Totp_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+T=ot.emd(a,b,M) # exact linear program
+T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
```
* Compute Wasserstein barycenter
```python
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index c7063e8..f2cdc35 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -61,8 +61,8 @@ pl.legend()
#%%
reg=1e-2
-d_sinkhorn=ot.sinkhorn(a,B,M,reg)
-d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
+d_sinkhorn=ot.sinkhorn2(a,B,M,reg)
+d_sinkhorn2=ot.sinkhorn2(a,B,M2,reg)
pl.figure(2)
pl.clf()
diff --git a/ot/__init__.py b/ot/__init__.py
index b2af88b..4220148 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -16,14 +16,14 @@ from . import da
# OT functions
from .lp import emd, emd2
-from .bregman import sinkhorn, barycenter
+from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.2"
+__version__ = "0.3"
-__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp',
- 'plot', 'tic', 'toc', 'toq',
+__all__ = ["emd", "emd2", "sinkhorn","sinkhorn2", "utils", 'datasets',
+ 'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
diff --git a/ot/bregman.py b/ot/bregman.py
index 68be01c..0d68602 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -41,7 +41,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
Regularization term >0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -91,7 +91,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
"""
-
+
if method.lower()=='sinkhorn':
sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,**kwargs)
@@ -100,15 +100,119 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
elif method.lower()=='sinkhorn_epsilon_scaling':
sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
print('Warning : unknown method using classic Sinkhorn Knopp')
sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
-
+
return sink()
+
+def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
+ u"""
+ Solve the entropic regularization optimal transport problem and return the loss
+
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ W : (nt) ndarray or float
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.sinkhorn2(a,b,M,1)
+ array([ 0.26894142])
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
+ ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
+
+ """
+
+ if method.lower()=='sinkhorn':
+ sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,**kwargs)
+ elif method.lower()=='sinkhorn_stabilized':
+ sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ elif method.lower()=='sinkhorn_epsilon_scaling':
+ sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ else:
+ print('Warning : unknown method using classic Sinkhorn Knopp')
+ sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
+
+ b=np.asarray(b,dtype=np.float64)
+ if len(b.shape)<2:
+ b=b.reshape((-1,1))
+
+ return sink()
+
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
"""
@@ -189,23 +293,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
a=np.asarray(a,dtype=np.float64)
b=np.asarray(b,dtype=np.float64)
M=np.asarray(M,dtype=np.float64)
-
+
if len(a)==0:
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
if len(b)==0:
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
-
+
# init data
Nini = len(a)
Nfin = len(b)
-
+
if len(b.shape)>1:
nbb=b.shape[1]
else:
nbb=0
-
+
if log:
log={'err':[]}
@@ -217,7 +321,7 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
else:
u = np.ones(Nini)/Nini
v = np.ones(Nfin)/Nfin
-
+
#print(reg)
@@ -261,23 +365,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
if log:
log['u']=u
log['v']=v
-
- if nbb: #return only loss
+
+ if nbb: #return only loss
res=np.zeros((nbb))
for i in range(nbb):
res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
if log:
return res,log
else:
- return res
-
+ return res
+
else: # return OT matrix
-
+
if log:
return u.reshape((-1,1))*K*v.reshape((1,-1)),log
else:
return u.reshape((-1,1))*K*v.reshape((1,-1))
-
+
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
"""
@@ -393,7 +497,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
alpha,beta=np.zeros(na),np.zeros(nb)
else:
alpha,beta=warmstart
-
+
if nbb:
u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
else:
@@ -420,7 +524,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
uprev = u
vprev = v
-
+
# sinkhorn update
v = b/(np.dot(K.T,u)+1e-16)
u = a/(np.dot(K,v)+1e-16)
@@ -471,8 +575,8 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
break
cpt = cpt +1
-
-
+
+
#print('err=',err,' cpt=',cpt)
if log:
log['logu']=alpha/reg+np.log(u)
@@ -493,7 +597,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
res=np.zeros((nbb))
for i in range(nbb):
res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
- return res
+ return res
else:
return get_Gamma(alpha,beta,u,v)