summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 11:15:42 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 11:15:42 +0200
commitcada9a3019997e8efb95d96c86985110f1e937b9 (patch)
treebc1cacefe01438ad44b0c8021bb0052d9190cafd /ot/lp/__init__.py
parent15b21611a3a93043d30c4eaaf9d622200453a884 (diff)
Sparse G matrix for EMD1d + standard metrics computed without cdist
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py54
1 files changed, 40 insertions, 14 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index c4457dc..decff29 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -10,6 +10,7 @@ Solvers for the original linear program OT problem
import multiprocessing
import numpy as np
+from scipy.sparse import coo_matrix
from .import cvx
@@ -19,7 +20,8 @@ from ..utils import parmap
from .cvx import barycenter
from ..utils import dist
-__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted']
+__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+ 'emd_1d', 'emd2_1d']
def emd(a, b, M, numItermax=100000, log=False):
@@ -311,16 +313,20 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X
-def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
+def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
"""Solves the Earth Movers distance problem between 1d measures and returns
the OT matrix
"""
- assert x_a.shape[1] == x_b.shape[1] == 1, "emd_1d should only be used " + \
- "with monodimensional data"
-
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
+ x_a = np.asarray(x_a, dtype=np.float64)
+ x_b = np.asarray(x_b, dtype=np.float64)
+
+ assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+ assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -328,16 +334,36 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
if len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
- perm_a = np.argsort(x_a.reshape((-1, )))
- perm_b = np.argsort(x_b.reshape((-1, )))
- inv_perm_a = np.argsort(perm_a)
- inv_perm_b = np.argsort(perm_b)
-
- G_sorted, cost = emd_1d_sorted(a, b, x_a[perm_a], x_b[perm_b],
- metric=metric)
- G = G_sorted[inv_perm_a, :][:, inv_perm_b]
+ x_a_1d = x_a.reshape((-1, ))
+ x_b_1d = x_b.reshape((-1, ))
+ perm_a = np.argsort(x_a_1d)
+ perm_b = np.argsort(x_b_1d)
+
+ G_sorted, indices, cost = emd_1d_sorted(a, b,
+ x_a_1d[perm_a], x_b_1d[perm_b],
+ metric=metric)
+ G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
+ shape=(a.shape[0], b.shape[0]))
+ if dense:
+ G = G.todense()
if log:
log = {}
log['cost'] = cost
return G, log
- return G \ No newline at end of file
+ return G
+
+
+def emd2_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
+ """Solves the Earth Movers distance problem between 1d measures and returns
+ the loss
+
+ """
+ # If we do not return G (log==False), then we should not to cast it to dense
+ # (useless overhead)
+ G, log_emd = emd_1d(a=a, b=b, x_a=x_a, x_b=x_b, metric=metric,
+ dense=dense and log, log=True)
+ cost = log_emd['cost']
+ if log:
+ log_emd = {'G': G}
+ return cost, log_emd
+ return cost \ No newline at end of file