diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-12-18 10:15:30 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2019-12-18 10:15:30 +0100 |
commit | 3cb03158c42dde141d6f33973ea6e3394b9dc3d4 (patch) | |
tree | 272412760f0e55b9a354b7c43ddc71a2a6320a69 /ot/lp/emd_wrap.pyx | |
parent | a4afee871d8e9d5db68228d1ed5bf4853eedc294 (diff) |
cleanup variable name dense
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r-- | ot/lp/emd_wrap.pyx | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 4e3586d..636a9e3 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -46,7 +46,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint sparse): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -110,8 +110,19 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 - if sparse: + if dense: + # init OT matrix + G=np.zeros([n1, n2]) + + # calling the function + result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + + return G, cost, alpha, beta, result_code + + else: + + # init sparse OT matrix Gv=np.zeros(nmax) iG=np.zeros(nmax,dtype=np.int) jG=np.zeros(nmax,dtype=np.int) @@ -123,17 +134,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code - else: - - - G=np.zeros([n1, n2]) - - - # calling the function - result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) - - return G, cost, alpha, beta, result_code - @cython.boundscheck(False) @cython.wraparound(False) |