summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2020-01-27 10:59:58 +0100
committerRémi Flamary <remi.flamary@gmail.com>2020-01-27 10:59:58 +0100
commit9a9b3547837eac56349ce8df92bb5b0565daa2d6 (patch)
treee898dde7d65c6414cbb9bf1604ec7419c07dd697 /ot/lp
parente5196fa7a8c493b831fd5dac52a89bbf29e7b0e6 (diff)
correct emd2 and add centering for dual potentials
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/__init__.py28
1 files changed, 27 insertions, 1 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index a771ce4..aa3166f 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -264,6 +264,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
if dense:
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+ if center_dual:
+ u, v = center_ot_dual(u, v, a, b)
+
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
@@ -271,6 +274,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ if center_dual:
+ u, v = center_ot_dual(u, v, a, b)
+
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
@@ -287,7 +293,8 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- numItermax=100000, log=False, dense=True, return_matrix=False):
+ numItermax=100000, log=False, dense=True, return_matrix=False,
+ center_dual=True):
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
@@ -329,6 +336,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
Otherwise returns a sparse representation using scipy's `coo_matrix`
format.
+ center_dual: boolean, optional (default=True)
+ If True, centers the dual potential using function
+ :ref:`center_ot_dual`.
Returns
-------
@@ -383,14 +393,23 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ asel = a != 0
+
if log or return_matrix:
def f(b):
+ bsel = b != 0
if dense:
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
else:
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ if center_dual:
+ u, v = center_ot_dual(u, v, a, b)
+
+ if np.any(~asel) or np.any(~bsel):
+ u, v = estimate_dual_null_weights(u, v, a, b, M)
+
result_code_string = check_result(result_code)
log = {}
if return_matrix:
@@ -402,12 +421,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return [cost, log]
else:
def f(b):
+ bsel = b != 0
if dense:
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
else:
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ if center_dual:
+ u, v = center_ot_dual(u, v, a, b)
+
+ if np.any(~asel) or np.any(~bsel):
+ u, v = estimate_dual_null_weights(u, v, a, b, M)
+
result_code_string = check_result(result_code)
check_result(result_code)
return cost