summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 09:14:19 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 09:14:19 +0200
commit0e8363ca13034537ca7bc64acd982f39b7a42123 (patch)
tree18adc06132943c2c9d0ba731a33a9b7174624b2f /ot/lp/__init__.py
parentcf2d92e151f816e6ddcfc4b64cbda1f8f7bde9df (diff)
doc sinkhorn
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py30
1 files changed, 22 insertions, 8 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 568e370..72b4cb8 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -7,7 +7,6 @@ def emd(a,b,M):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
- gamm=emd(a,b,M)
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F
@@ -21,6 +20,8 @@ def emd(a,b,M):
- M is the metric cost matrix
- a and b are the sample weights
+
+ Uses the algorithm proposed in [1]_
Parameters
----------
@@ -31,11 +32,17 @@ def emd(a,b,M):
M : (ns,nt) ndarray, float64
loss matrix
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+
+
Examples
--------
- Simple example with obvious solution. The function :func:emd accepts lists and
- perform automatic conversion tu numpy arrays
+ Simple example with obvious solution. The function emd accepts lists and
+ perform automatic conversion to numpy arrays
>>> a=[.5,.5]
>>> b=[.5,.5]
@@ -43,15 +50,22 @@ def emd(a,b,M):
>>> ot.emd(a,b,M)
array([[ 0.5, 0. ],
[ 0. , 0.5]])
+
+ References
+ ----------
- Returns
- -------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
-
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+
"""
a=np.asarray(a,dtype=np.float64)
b=np.asarray(b,dtype=np.float64)
+ M=np.asarray(M,dtype=np.float64)
if len(a)==0:
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]