diff options
Diffstat (limited to 'ot/emd/emd.pyx')
-rw-r--r-- | ot/emd/emd.pyx | 44 |
1 files changed, 34 insertions, 10 deletions
diff --git a/ot/emd/emd.pyx b/ot/emd/emd.pyx index e5ac8e0..753b195 100644 --- a/ot/emd/emd.pyx +++ b/ot/emd/emd.pyx @@ -22,17 +22,35 @@ def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode= """ Solves the Earth Movers distance problem and returns the optimal transport matrix + gamm=emd(a,b,M) + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F - - :param a: m weights of the source distribution (must sum to one) - :param b: n weights of the target distribution (must sum to one) - :param M: m x n cost matrix - :type a: np.ndarray - :type b: np.ndarray - :type M: np.ndarray - :return: Optimal transport matrix - :rtype: np.ndarray - + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the metric cost matrix + - a and b are the sample weights + + Parameters + ---------- + a : (ns,) ndarray + samples in the source domain (uniform waigth if empty) + b : (nt,) ndarray + samples in the target domain (uniform waigth if empty) + M : (ns,nt) ndarray + loss matrix + + + Returns + ------- + gamma: (ns x nt) ndarray + Optimal transportation matrix for the given parameters """ cdef int n1= M.shape[0] @@ -40,6 +58,12 @@ def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode= cdef float cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + + if not len(a): + a=np.ones((n1,))/n1 + + if not len(b): + b=np.ones((n2,))/n2 # calling the function EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost) |