summaryrefslogtreecommitdiff
path: root/ot/emd/emd.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'ot/emd/emd.pyx')
-rw-r--r--ot/emd/emd.pyx16
1 files changed, 16 insertions, 0 deletions
diff --git a/ot/emd/emd.pyx b/ot/emd/emd.pyx
index d090cea..e5ac8e0 100644
--- a/ot/emd/emd.pyx
+++ b/ot/emd/emd.pyx
@@ -19,6 +19,22 @@ cdef extern from "EMD.h":
@cython.boundscheck(False)
@cython.wraparound(False)
def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+ """
+ Solves the Earth Movers distance problem and returns the optimal transport matrix
+
+
+
+ :param a: m weights of the source distribution (must sum to one)
+ :param b: n weights of the target distribution (must sum to one)
+ :param M: m x n cost matrix
+ :type a: np.ndarray
+ :type b: np.ndarray
+ :type M: np.ndarray
+ :return: Optimal transport matrix
+ :rtype: np.ndarray
+
+
+ """
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]