summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:52:23 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:52:23 +0200
commit15b21611a3a93043d30c4eaaf9d622200453a884 (patch)
tree20ff23c61f35b876e56d435a5c522249f396e3c5 /ot
parentf63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba (diff)
EMD 1d without doc made faster
Diffstat (limited to 'ot')
-rw-r--r--ot/lp/__init__.py5
-rw-r--r--ot/lp/emd_wrap.pyx14
2 files changed, 13 insertions, 6 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 49ded5b..c4457dc 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -333,9 +333,8 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
inv_perm_a = np.argsort(perm_a)
inv_perm_b = np.argsort(perm_b)
- M = dist(x_a[perm_a], x_b[perm_b], metric=metric)
-
- G_sorted, cost = emd_1d_sorted(a, b, M)
+ G_sorted, cost = emd_1d_sorted(a, b, x_a[perm_a], x_b[perm_b],
+ metric=metric)
G = G_sorted[inv_perm_a, :][:, inv_perm_b]
if log:
log = {}
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index a3d189d..2966206 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -10,6 +10,8 @@ Cython linker with C solver
import numpy as np
cimport numpy as np
+from ..utils import dist
+
cimport cython
import warnings
@@ -99,7 +101,9 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
@cython.wraparound(False)
def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
np.ndarray[double, ndim=1, mode="c"] v_weights,
- np.ndarray[double, ndim=2, mode="c"] M):
+ np.ndarray[double, ndim=2, mode="c"] u,
+ np.ndarray[double, ndim=2, mode="c"] v,
+ str metric='sqeuclidean'):
r"""
Roro's stuff
"""
@@ -112,17 +116,21 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cdef int j = 0
cdef double w_j = v_weights[0]
+ cdef double m_ij = 0.
+
cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
dtype=np.float64)
while i < n and j < m:
+ m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
+ metric=metric)[0, 0]
if w_i < w_j or j == m - 1:
- cost += M[i, j] * w_i
+ cost += m_ij * w_i
G[i, j] = w_i
i += 1
w_j -= w_i
w_i = u_weights[i]
else:
- cost += M[i, j] * w_j
+ cost += m_ij * w_j
G[i, j] = w_j
j += 1
w_i -= w_j