diff options
author | Kilian <kilian.fatras@gmail.com> | 2020-01-07 13:16:54 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-07 13:16:54 +0100 |
commit | 40154746a5a6cab0b6f17c284967eb8303fcc3f6 (patch) | |
tree | 2622569a36aab5fc2eb2d2adc45a16f861a65dc5 /ot/lp/__init__.py | |
parent | a9bbc2cfdffd22ceee3256102e470df6c25338f3 (diff) | |
parent | c5039bcafde999114283f7e59fb03e176027d740 (diff) |
Merge branch 'master' into emd_dimension
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 39 |
1 file changed, 32 insertions, 7 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4cce41c..eabdd3a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -27,7 +27,7 @@ __all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] -def emd(a, b, M, numItermax=100000, log=False): +def emd(a, b, M, numItermax=100000, log=False, dense=True): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -62,6 +62,10 @@ def emd(a, b, M, numItermax=100000, log=False): log: bool, optional (default=False) If True, returns a dictionary containing the cost and dual variables. Otherwise returns only the optimal transportation matrix. + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Returns ------- @@ -103,6 +107,7 @@ def emd(a, b, M, numItermax=100000, log=False): b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] @@ -112,7 +117,12 @@ def emd(a, b, M, numItermax=100000, log=False): assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + if dense: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + else: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + result_code_string = check_result(result_code) if log: log = {} @@ -126,7 +136,7 @@ def emd(a, b, M, numItermax=100000, log=False): def emd2(a, b, M, processes=multiprocessing.cpu_count(), - numItermax=100000, log=False, return_matrix=False): + numItermax=100000, log=False, dense=True, return_matrix=False): r"""Solves the Earth Movers distance problem and returns the loss 
.. math:: @@ -164,6 +174,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), variables. Otherwise returns only the optimal transportation cost. return_matrix: boolean, optional (default=False) If True, returns the optimal transportation matrix in the log. + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Returns ------- @@ -220,19 +234,30 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if log or return_matrix: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - result_code_string = check_result(resultCode) + if dense: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + else: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + + result_code_string = check_result(result_code) log = {} if return_matrix: log['G'] = G log['u'] = u log['v'] = v log['warning'] = result_code_string - log['result_code'] = resultCode + log['result_code'] = result_code return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + if dense: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + else: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + + result_code_string = check_result(result_code) check_result(result_code) return cost |