summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-12-18 10:15:30 +0100
committerRémi Flamary <remi.flamary@gmail.com>2019-12-18 10:15:30 +0100
commit3cb03158c42dde141d6f33973ea6e3394b9dc3d4 (patch)
tree272412760f0e55b9a354b7c43ddc71a2a6320a69 /ot
parenta4afee871d8e9d5db68228d1ed5bf4853eedc294 (diff)
cleanup variable name dense
Diffstat (limited to 'ot')
-rw-r--r--ot/lp/__init__.py30
-rw-r--r--ot/lp/emd_wrap.pyx26
2 files changed, 27 insertions, 29 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index d476071..bb9829a 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -107,7 +107,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- sparse= not dense
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -115,11 +114,11 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- if sparse:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ if dense:
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
else:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
result_code_string = check_result(result_code)
if log:
@@ -217,8 +216,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- sparse=not dense
-
# problem with pikling Forks
if sys.platform.endswith('win32'):
processes=1
@@ -231,12 +228,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if log or return_matrix:
def f(b):
-
- if sparse:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ if dense:
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
else:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
result_code_string = check_result(result_code)
log = {}
@@ -249,11 +245,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return [cost, log]
else:
def f(b):
- if sparse:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ if dense:
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
else:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+
+ result_code_string = check_result(result_code)
check_result(result_code)
return cost
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 4e3586d..636a9e3 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -46,7 +46,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint sparse):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -110,8 +110,19 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
if not len(b):
b=np.ones((n2,))/n2
- if sparse:
+ if dense:
+ # init OT matrix
+ G=np.zeros([n1, n2])
+
+ # calling the function
+ result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+
+ return G, cost, alpha, beta, result_code
+
+ else:
+
+ # init sparse OT matrix
Gv=np.zeros(nmax)
iG=np.zeros(nmax,dtype=np.int)
jG=np.zeros(nmax,dtype=np.int)
@@ -123,17 +134,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code
- else:
-
-
- G=np.zeros([n1, n2])
-
-
- # calling the function
- result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
-
- return G, cost, alpha, beta, result_code
-
@cython.boundscheck(False)
@cython.wraparound(False)