summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:29:56 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:29:56 +0200
commitf63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba (patch)
tree96dd2a29842c86a3e3875feba1e8fa8ad3076eb7 /ot/lp/emd_wrap.pyx
parent5a6b226de20624b51c2ff98bc30e5611a7a788c7 (diff)
EMD 1d without doc
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx35
1 files changed, 35 insertions, 0 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 83ee6aa..a3d189d 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -93,3 +93,38 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
return G, cost, alpha, beta, result_code
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
+ np.ndarray[double, ndim=1, mode="c"] v_weights,
+ np.ndarray[double, ndim=2, mode="c"] M):
+ r"""
+ Roro's stuff
+ """
+ cdef double cost = 0.
+ cdef int n = u_weights.shape[0]
+ cdef int m = v_weights.shape[0]
+
+ cdef int i = 0
+ cdef double w_i = u_weights[0]
+ cdef int j = 0
+ cdef double w_j = v_weights[0]
+
+ cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
+ dtype=np.float64)
+ while i < n and j < m:
+ if w_i < w_j or j == m - 1:
+ cost += M[i, j] * w_i
+ G[i, j] = w_i
+ i += 1
+ w_j -= w_i
+ w_i = u_weights[i]
+ else:
+ cost += M[i, j] * w_j
+ G[i, j] = w_j
+ j += 1
+ w_i -= w_j
+ w_j = v_weights[j]
+ return G, cost \ No newline at end of file