summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2020-04-23 13:03:28 +0200
committerGitHub <noreply@github.com>2020-04-23 13:03:28 +0200
commitef12867f1425ee86b3cfddef4287b52d46114e83 (patch)
tree38e023c5561b1669f4d8e602feb6728f51e1b359 /ot/lp/__init__.py
parentbacb0b992aa4e1ba7e5fd0beb0bf9617c801f833 (diff)
[WIP] Issue with sparse emd and adding tests on macos (#158)
* First commit-warning removal * remove dense feature * pep8 * pep8 * EMD.h * pep8 again * tic toc tolerance Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py45
1 files changed, 11 insertions, 34 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 8d1baa0..ad390c5 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -172,7 +172,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
return center_ot_dual(alpha, beta, a, b)
-def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
+def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -207,10 +207,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
log: bool, optional (default=False)
If True, returns a dictionary containing the cost and dual
variables. Otherwise returns only the optimal transportation matrix.
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format.
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:ref:`center_ot_dual`.
@@ -267,25 +263,14 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
asel = a != 0
bsel = b != 0
- if dense:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
-
- if center_dual:
- u, v = center_ot_dual(u, v, a, b)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
- if np.any(~asel) or np.any(~bsel):
- u, v = estimate_dual_null_weights(u, v, a, b, M)
-
- else:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
-
- if center_dual:
- u, v = center_ot_dual(u, v, a, b)
-
- if np.any(~asel) or np.any(~bsel):
- u, v = estimate_dual_null_weights(u, v, a, b, M)
+ if center_dual:
+ u, v = center_ot_dual(u, v, a, b)
+ if np.any(~asel) or np.any(~bsel):
+ u, v = estimate_dual_null_weights(u, v, a, b, M)
+
result_code_string = check_result(result_code)
if log:
log = {}
@@ -299,7 +284,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- numItermax=100000, log=False, dense=True, return_matrix=False,
+ numItermax=100000, log=False, return_matrix=False,
center_dual=True):
r"""Solves the Earth Movers distance problem and returns the loss
@@ -404,11 +389,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if log or return_matrix:
def f(b):
bsel = b != 0
- if dense:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
- else:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -428,11 +410,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
else:
def f(b):
bsel = b != 0
- if dense:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
- else:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -440,7 +418,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
- result_code_string = check_result(result_code)
check_result(result_code)
return cost