summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-06-13 14:39:59 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-06-13 14:39:59 +0200
commita632c40b97c51534bfe65fee8b493780df118039 (patch)
tree41d66ab22a5d00052afa94167218a6b5996b63d4 /ot/bregman.py
parent05da582675c89ab20998e1a9505bf3c220e296b8 (diff)
make sinkhorn more general with method selection
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py110
1 files changed, 107 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 00dca88..6b3c68b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -5,8 +5,112 @@ Bregman projections for regularized OT
import numpy as np
+def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
+ u"""
+ Solve the entropic regularization optimal transport problem
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.sinkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
+ ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
+
+ """
+
+ if method.lower()=='sinkhorn':
+ sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,**kwargs)
+ elif method.lower()=='sinkhorn_stabilized':
+ sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ elif method.lower()=='sinkhorn_epsilon_scaling':
+ sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ else:
+ print('Warning : unknown method using classic Sinkhorn Knopp')
+ sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
+
+ return sink()
+
+
+
+
-def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False):
+def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
"""
Solve the entropic regularization optimal transport problem
@@ -147,7 +251,7 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
else:
return u.reshape((-1,1))*K*v.reshape((1,-1))
-def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False):
+def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
"""
Solve the entropic regularization OT problem with log stabilization
@@ -331,7 +435,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
else:
return get_Gamma(alpha,beta,u,v)
-def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False):
+def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs):
"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.