summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py26
-rw-r--r--ot/lp/__init__.py30
2 files changed, 46 insertions, 10 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 0d2c099..5d93fd6 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -26,7 +26,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
@@ -46,10 +46,22 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+
+ Examples
+ --------
+
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.sinkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+
References
----------
- .. [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
See Also
@@ -58,6 +70,16 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
ot.optim.cg : General regularized OT
"""
+
+ a=np.asarray(a,dtype=np.float64)
+ b=np.asarray(b,dtype=np.float64)
+ M=np.asarray(M,dtype=np.float64)
+
+ if len(a)==0:
+ a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
+ if len(b)==0:
+ b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
# init data
Nini = len(a)
Nfin = len(b)
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 568e370..72b4cb8 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -7,7 +7,6 @@ def emd(a,b,M):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
- gamm=emd(a,b,M)
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F
@@ -21,6 +20,8 @@ def emd(a,b,M):
- M is the metric cost matrix
- a and b are the sample weights
+
+ Uses the algorithm proposed in [1]_
Parameters
----------
@@ -31,11 +32,17 @@ def emd(a,b,M):
M : (ns,nt) ndarray, float64
loss matrix
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+
+
Examples
--------
- Simple example with obvious solution. The function :func:emd accepts lists and
- perform automatic conversion tu numpy arrays
+ Simple example with obvious solution. The function emd accepts lists and
+ perform automatic conversion to numpy arrays
>>> a=[.5,.5]
>>> b=[.5,.5]
@@ -43,15 +50,22 @@ def emd(a,b,M):
>>> ot.emd(a,b,M)
array([[ 0.5, 0. ],
[ 0. , 0.5]])
+
+ References
+ ----------
- Returns
- -------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
-
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+
"""
a=np.asarray(a,dtype=np.float64)
b=np.asarray(b,dtype=np.float64)
+ M=np.asarray(M,dtype=np.float64)
if len(a)==0:
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]