diff options
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 29 |
1 files changed, 22 insertions, 7 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 0c92810..4fec7d9 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -27,7 +27,7 @@ __all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] -def emd(a, b, M, numItermax=100000, log=False): +def emd(a, b, M, numItermax=100000, log=False, sparse=False): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -109,7 +109,12 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + if sparse: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + result_code_string = check_result(result_code) if log: log = {} @@ -123,7 +128,7 @@ def emd(a, b, M, numItermax=100000, log=False): def emd2(a, b, M, processes=multiprocessing.cpu_count(), - numItermax=100000, log=False, return_matrix=False): + numItermax=100000, log=False, sparse=False, return_matrix=False): r"""Solves the Earth Movers distance problem and returns the loss .. math:: @@ -214,19 +219,29 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if log or return_matrix: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - result_code_string = check_result(resultCode) + + if sparse: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + + result_code_string = check_result(result_code) log = {} if return_matrix: log['G'] = G log['u'] = u log['v'] = v log['warning'] = result_code_string - log['result_code'] = resultCode + log['result_code'] = result_code return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + if sparse: + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse) check_result(result_code) return cost |