From 4a6883e0ce2fd9f3edd374d54c4c219d876ceb76 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 09:37:54 +0100 Subject: nothing explodes and it compiles --- ot/lp/EMD.h | 4 ++++ ot/lp/EMD_wrapper.cpp | 2 +- ot/lp/__init__.py | 29 ++++++++++++++++++++++------- ot/lp/emd_wrap.pyx | 38 +++++++++++++++++++++++++++++++++----- 4 files changed, 60 insertions(+), 13 deletions(-) diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index f42e222..bc513d2 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -32,4 +32,8 @@ enum ProblemType { int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); +int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, + long *iG, long *jG, double *G, + double* alpha, double* beta, double *cost, int maxIter); + #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 65fa80f..3ca7319 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -108,7 +108,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, - int *iG, int *jG, double *G, + long *iG, long *jG, double *G, double* alpha, double* beta, double *cost, int maxIter) { // beware M and C anre strored in row major C style!!! 
int n, m, i, cur; diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 0c92810..4fec7d9 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -27,7 +27,7 @@ __all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] -def emd(a, b, M, numItermax=100000, log=False): +def emd(a, b, M, numItermax=100000, log=False, sparse=False): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -109,7 +109,12 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + if sparse: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + result_code_string = check_result(result_code) if log: log = {} @@ -123,7 +128,7 @@ def emd(a, b, M, numItermax=100000, log=False): def emd2(a, b, M, processes=multiprocessing.cpu_count(), - numItermax=100000, log=False, return_matrix=False): + numItermax=100000, log=False, sparse=False, return_matrix=False): r"""Solves the Earth Movers distance problem and returns the loss .. 
math:: @@ -214,19 +219,29 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if log or return_matrix: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - result_code_string = check_result(resultCode) + + if sparse: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + + result_code_string = check_result(result_code) log = {} if return_matrix: log['G'] = G log['u'] = u log['v'] = v log['warning'] = result_code_string - log['result_code'] = resultCode + log['result_code'] = result_code return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + if sparse: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) check_result(result_code) return cost diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2b6c495..345cb66 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -20,6 +20,9 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) + int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, + long *iG, long *jG, double *G, + double* alpha, double* beta, double *cost, int maxIter) cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -39,7 +42,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint sparse): 
""" Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -82,12 +85,18 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod """ cdef int n1= M.shape[0] cdef int n2= M.shape[1] + cdef int nmax=n1+n2-1 + cdef int result_code = 0 cdef double cost=0 - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) + + cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) + cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int) + cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int) if not len(a): a=np.ones((n1,))/n1 @@ -95,10 +104,29 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 - # calling the function - cdef int result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) + if sparse: + + Gv=np.zeros(nmax) + iG=np.zeros(nmax,dtype=np.int) + jG=np.zeros(nmax,dtype=np.int) + + + result_code = EMD_wrap_return_sparse(n1, n2, a.data, b.data, M.data, iG.data, jG.data, G.data, alpha.data, beta.data, &cost, max_iter) + + + return Gv, iG, jG, cost, alpha, beta, result_code + + + else: + + + G=np.zeros([n1, n2]) + + + # calling the function + result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) - return G, cost, alpha, beta, result_code + return G, cost, alpha, beta, result_code @cython.boundscheck(False) -- cgit v1.2.3