diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-12-18 10:15:30 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2019-12-18 10:15:30 +0100 |
commit | 3cb03158c42dde141d6f33973ea6e3394b9dc3d4 (patch) | |
tree | 272412760f0e55b9a354b7c43ddc71a2a6320a69 /ot | |
parent | a4afee871d8e9d5db68228d1ed5bf4853eedc294 (diff) |
cleanup variable name dense
Diffstat (limited to 'ot')
-rw-r--r-- | ot/lp/__init__.py | 30 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 26 |
2 files changed, 27 insertions, 29 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d476071..bb9829a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -107,7 +107,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True): b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - sparse= not dense # if empty array given then use uniform distributions if len(a) == 0: @@ -115,11 +114,11 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - if sparse: - Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) - G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if dense: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) else: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) result_code_string = check_result(result_code) if log: @@ -217,8 +216,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - sparse=not dense - # problem with pikling Forks if sys.platform.endswith('win32'): processes=1 @@ -231,12 +228,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if log or return_matrix: def f(b): - - if sparse: - Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) - G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if dense: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) else: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) result_code_string = check_result(result_code) log = {} @@ -249,11 +245,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return [cost, log] else: def f(b): - if sparse: - Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) - G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if dense: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) else: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + + result_code_string = check_result(result_code) check_result(result_code) return cost diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 4e3586d..636a9e3 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -46,7 +46,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint sparse): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -110,8 +110,19 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 - if sparse: + if dense: + # init OT matrix + G=np.zeros([n1, n2]) + + # calling the function + result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + + return G, cost, alpha, beta, result_code + + else: + + # init sparse OT matrix Gv=np.zeros(nmax) iG=np.zeros(nmax,dtype=np.int) jG=np.zeros(nmax,dtype=np.int) @@ -123,17 +134,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code - else: - - - G=np.zeros([n1, n2]) - - - # calling the function - result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) - - return G, cost, alpha, beta, result_code - @cython.boundscheck(False) @cython.wraparound(False) |