summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx32
1 files changed, 19 insertions, 13 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index c167964..42e08f4 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -20,6 +20,7 @@ import warnings
cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -38,7 +39,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -97,8 +98,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])
cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)
- cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int)
- cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int)
if not len(a):
a=np.ones((n1,))/n1
@@ -111,8 +110,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
# calling the function
with nogil:
- result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
-
+ if numThreads == 1:
+ result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ else:
+ result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
return G, cost, alpha, beta, result_code
@@ -157,22 +158,22 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cost associated to the optimal transportation
"""
cdef double cost = 0.
- cdef int n = u_weights.shape[0]
- cdef int m = v_weights.shape[0]
+ cdef Py_ssize_t n = u_weights.shape[0]
+ cdef Py_ssize_t m = v_weights.shape[0]
- cdef int i = 0
+ cdef Py_ssize_t i = 0
cdef double w_i = u_weights[0]
- cdef int j = 0
+ cdef Py_ssize_t j = 0
cdef double w_j = v_weights[0]
cdef double m_ij = 0.
cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
dtype=np.float64)
- cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
- dtype=np.int)
- cdef int cur_idx = 0
- while i < n and j < m:
+ cdef np.ndarray[long long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+ dtype=np.int64)
+ cdef Py_ssize_t cur_idx = 0
+ while True:
if metric == 'sqeuclidean':
m_ij = (u[i] - v[j]) * (u[i] - v[j])
elif metric == 'cityblock' or metric == 'euclidean':
@@ -188,6 +189,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
i += 1
+ if i == n:
+ break
w_j -= w_i
w_i = u_weights[i]
else:
@@ -196,7 +199,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
j += 1
+ if j == m:
+ break
w_i -= w_j
w_j = v_weights[j]
cur_idx += 1
+ cur_idx += 1
return G[:cur_idx], indices[:cur_idx], cost