summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py26
1 files changed, 24 insertions, 2 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 0d2c099..5d93fd6 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -26,7 +26,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
@@ -46,10 +46,22 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+
+ Examples
+ --------
+
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.sinkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+
References
----------
- .. [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
See Also
@@ -58,6 +70,16 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
ot.optim.cg : General regularized OT
"""
+
+ a=np.asarray(a,dtype=np.float64)
+ b=np.asarray(b,dtype=np.float64)
+ M=np.asarray(M,dtype=np.float64)
+
+ if len(a)==0:
+ a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
+ if len(b)==0:
+ b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
# init data
Nini = len(a)
Nfin = len(b)