summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorKilian <kilian.fatras@gmail.com>2020-01-07 13:16:54 +0100
committerGitHub <noreply@github.com>2020-01-07 13:16:54 +0100
commit40154746a5a6cab0b6f17c284967eb8303fcc3f6 (patch)
tree2622569a36aab5fc2eb2d2adc45a16f861a65dc5 /ot/lp/__init__.py
parenta9bbc2cfdffd22ceee3256102e470df6c25338f3 (diff)
parentc5039bcafde999114283f7e59fb03e176027d740 (diff)
Merge branch 'master' into emd_dimension
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py39
1 files changed, 32 insertions, 7 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4cce41c..eabdd3a 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -27,7 +27,7 @@ __all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
-def emd(a, b, M, numItermax=100000, log=False):
+def emd(a, b, M, numItermax=100000, log=False, dense=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -62,6 +62,10 @@ def emd(a, b, M, numItermax=100000, log=False):
log: bool, optional (default=False)
If True, returns a dictionary containing the cost and dual
variables. Otherwise returns only the optimal transportation matrix.
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format.
Returns
-------
@@ -103,6 +107,7 @@ def emd(a, b, M, numItermax=100000, log=False):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
+
# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -112,7 +117,12 @@ def emd(a, b, M, numItermax=100000, log=False):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ if dense:
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ else:
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+
result_code_string = check_result(result_code)
if log:
log = {}
@@ -126,7 +136,7 @@ def emd(a, b, M, numItermax=100000, log=False):
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- numItermax=100000, log=False, return_matrix=False):
+ numItermax=100000, log=False, dense=True, return_matrix=False):
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
@@ -164,6 +174,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
variables. Otherwise returns only the optimal transportation cost.
return_matrix: boolean, optional (default=False)
If True, returns the optimal transportation matrix in the log.
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format.
Returns
-------
@@ -220,19 +234,30 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if log or return_matrix:
def f(b):
- G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
- result_code_string = check_result(resultCode)
+ if dense:
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ else:
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+
+ result_code_string = check_result(result_code)
log = {}
if return_matrix:
log['G'] = G
log['u'] = u
log['v'] = v
log['warning'] = result_code_string
- log['result_code'] = resultCode
+ log['result_code'] = result_code
return [cost, log]
else:
def f(b):
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ if dense:
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ else:
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+
+ result_code_string = check_result(result_code)
check_result(result_code)
return cost