diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2020-01-27 10:59:58 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2020-01-27 10:59:58 +0100 |
commit | 9a9b3547837eac56349ce8df92bb5b0565daa2d6 (patch) | |
tree | e898dde7d65c6414cbb9bf1604ec7419c07dd697 /ot/lp | |
parent | e5196fa7a8c493b831fd5dac52a89bbf29e7b0e6 (diff) |
correct emd2 and add centering for dual potentials
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index a771ce4..aa3166f 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -264,6 +264,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True): if dense: G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) + if center_dual: + u, v = center_ot_dual(u, v, a, b) + if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) @@ -271,6 +274,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True): Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if center_dual: + u, v = center_ot_dual(u, v, a, b) + if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) @@ -287,7 +293,8 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True): def emd2(a, b, M, processes=multiprocessing.cpu_count(), - numItermax=100000, log=False, dense=True, return_matrix=False): + numItermax=100000, log=False, dense=True, return_matrix=False, + center_dual=True): r"""Solves the Earth Movers distance problem and returns the loss .. math:: @@ -329,6 +336,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy's `coo_matrix` format. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :ref:`center_ot_dual`. Returns ------- @@ -383,14 +393,23 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + asel = a != 0 + if log or return_matrix: def f(b): + bsel = b != 0 if dense: G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) else: Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + result_code_string = check_result(result_code) log = {} if return_matrix: @@ -402,12 +421,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return [cost, log] else: def f(b): + bsel = b != 0 if dense: G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) else: Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + result_code_string = check_result(result_code) check_result(result_code) return cost |