summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 15:13:48 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 15:13:48 +0200
commit0d81de9909e8e9eb95858f0a043550b15898f172 (patch)
tree52a456562616b8d015c8985a072bcf8759e84fa7 /ot/lp/__init__.py
parent4ced7428bd2be4ca008f12400afb445c5a6517c8 (diff)
doc da.py
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py77
1 files changed, 38 insertions, 39 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 6c7822a..2adf937 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -1,31 +1,30 @@
+# -*- coding: utf-8 -*-
"""
Solvers for the original linear program OT problem
"""
+import numpy as np
# import compiled emd
from .emd import emd_c
-import numpy as np
-def emd(a,b,M):
- """
- Solves the Earth Movers distance problem and returns the optimal transport matrix
-
-
+
+def emd(a, b, M):
+ """Solves the Earth Movers distance problem and returns the OT matrix
+
+
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
+ \gamma^T 1= b
\gamma\geq 0
where :
-
+
- M is the metric cost matrix
- a and b are the sample weights
-
+
Uses the algorithm proposed in [1]_
-
+
Parameters
----------
a : (ns,) ndarray, float64
@@ -33,47 +32,47 @@ def emd(a,b,M):
b : (nt,) ndarray, float64
Target histogram (uniform weigth if empty list)
M : (ns,nt) ndarray, float64
- loss matrix
-
+ loss matrix
+
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
-
-
+
+
Examples
--------
-
+
Simple example with obvious solution. The function emd accepts lists and
- perform automatic conversion to numpy arrays
-
+ perform automatic conversion to numpy arrays
+
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.emd(a,b,M)
array([[ 0.5, 0. ],
[ 0. , 0.5]])
-
+
References
----------
-
- .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
-
+
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
+ (2011, December). Displacement interpolation using Lagrangian mass
+ transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
+ 158). ACM.
+
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT
-
-
- """
- a=np.asarray(a,dtype=np.float64)
- b=np.asarray(b,dtype=np.float64)
- M=np.asarray(M,dtype=np.float64)
-
- if len(a)==0:
- a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
- if len(b)==0:
- b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
-
- return emd_c(a,b,M)
+ ot.optim.cg : General regularized OT"""
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
+ return emd_c(a, b, M)