diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 09:14:19 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 09:14:19 +0200 |
commit | 0e8363ca13034537ca7bc64acd982f39b7a42123 (patch) | |
tree | 18adc06132943c2c9d0ba731a33a9b7174624b2f /ot/lp/__init__.py | |
parent | cf2d92e151f816e6ddcfc4b64cbda1f8f7bde9df (diff) |
doc sinkhorn
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 30 |
1 files changed, 22 insertions, 8 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 568e370..72b4cb8 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -7,7 +7,6 @@ def emd(a,b,M): """ Solves the Earth Movers distance problem and returns the optimal transport matrix - gamm=emd(a,b,M) .. math:: \gamma = arg\min_\gamma <\gamma,M>_F @@ -21,6 +20,8 @@ def emd(a,b,M): - M is the metric cost matrix - a and b are the sample weights + + Uses the algorithm proposed in [1]_ Parameters ---------- @@ -31,11 +32,17 @@ def emd(a,b,M): M : (ns,nt) ndarray, float64 loss matrix + Returns + ------- + gamma: (ns x nt) ndarray + Optimal transportation matrix for the given parameters + + Examples -------- - Simple example with obvious solution. The function :func:emd accepts lists and - perform automatic conversion tu numpy arrays + Simple example with obvious solution. The function emd accepts lists and + perform automatic conversion to numpy arrays >>> a=[.5,.5] >>> b=[.5,.5] @@ -43,15 +50,22 @@ def emd(a,b,M): >>> ot.emd(a,b,M) array([[ 0.5, 0. ], [ 0. , 0.5]]) + + References + ---------- - Returns - ------- - gamma: (ns x nt) ndarray - Optimal transportation matrix for the given parameters - + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + + """ a=np.asarray(a,dtype=np.float64) b=np.asarray(b,dtype=np.float64) + M=np.asarray(M,dtype=np.float64) if len(a)==0: a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0] |