summaryrefslogtreecommitdiff
path: root/ot/emd/emd.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'ot/emd/emd.pyx')
-rw-r--r--ot/emd/emd.pyx44
1 files changed, 34 insertions, 10 deletions
diff --git a/ot/emd/emd.pyx b/ot/emd/emd.pyx
index e5ac8e0..753b195 100644
--- a/ot/emd/emd.pyx
+++ b/ot/emd/emd.pyx
@@ -22,17 +22,35 @@ def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode=
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
+ gamm=emd(a,b,M)
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F
-
- :param a: m weights of the source distribution (must sum to one)
- :param b: n weights of the target distribution (must sum to one)
- :param M: m x n cost matrix
- :type a: np.ndarray
- :type b: np.ndarray
- :type M: np.ndarray
- :return: Optimal transport matrix
- :rtype: np.ndarray
-
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the metric cost matrix
+ - a and b are the sample weights
+
+ Parameters
+ ----------
+ a : (ns,) ndarray
+ samples in the source domain (uniform waigth if empty)
+ b : (nt,) ndarray
+ samples in the target domain (uniform waigth if empty)
+ M : (ns,nt) ndarray
+ loss matrix
+
+
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
"""
cdef int n1= M.shape[0]
@@ -40,6 +58,12 @@ def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode=
cdef float cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+
+ if not len(a):
+ a=np.ones((n1,))/n1
+
+ if not len(b):
+ b=np.ones((n2,))/n2
# calling the function
EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost)