summary | refs | log | tree | commit | diff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
author: Romain Tavenard <romain.tavenard@univ-rennes2.fr> 2019-06-20 14:29:56 +0200
committer: Romain Tavenard <romain.tavenard@univ-rennes2.fr> 2019-06-20 14:29:56 +0200
commit: f63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba (patch)
tree: 96dd2a29842c86a3e3875feba1e8fa8ad3076eb7 /ot/lp/__init__.py
parent: 5a6b226de20624b51c2ff98bc30e5611a7a788c7 (diff)
EMD 1d without doc
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- ot/lp/__init__.py 43
1 files changed, 38 insertions, 5 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 02cbd8c..49ded5b 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -14,12 +14,12 @@ import numpy as np
from .import cvx
# import compiled emd
-from .emd_wrap import emd_c, check_result
+from .emd_wrap import emd_c, check_result, emd_1d_sorted
from ..utils import parmap
from .cvx import barycenter
from ..utils import dist
-__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
+__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted']
def emd(a, b, M, numItermax=100000, log=False):
@@ -94,7 +94,7 @@ def emd(a, b, M, numItermax=100000, log=False):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- # if empty array given then use unifor distributions
+ # if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
@@ -187,7 +187,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- # if empty array given then use unifor distributions
+ # if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
@@ -308,4 +308,37 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
log_dict['displacement_square_norms'] = displacement_square_norms
return X, log_dict
else:
- return X \ No newline at end of file
+ return X
+
+
def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
    """Solve the Earth Mover's Distance problem between 1d measures.

    Both supports are sorted, the sorted problem is handed to
    ``emd_1d_sorted`` and the plan it returns is mapped back to the
    original ordering of the samples.

    Parameters
    ----------
    a : (ns,) array-like
        Weights of the source samples. An empty array selects uniform
        weights over the ``ns`` source samples.
    b : (nt,) array-like
        Weights of the target samples. An empty array selects uniform
        weights over the ``nt`` target samples.
    x_a : (ns,) or (ns, 1) ndarray
        Positions of the source samples.
    x_b : (nt,) or (nt, 1) ndarray
        Positions of the target samples.
    metric : str, optional
        Metric name forwarded to ``dist`` to build the ground cost matrix
        between the sorted supports.
    log : bool, optional
        If True, a dictionary holding the cost is returned as well.

    Returns
    -------
    G : (ns, nt) ndarray
        Transport plan whose rows follow the order of ``x_a`` and whose
        columns follow the order of ``x_b`` as they were passed in.
    log : dict
        Only returned when ``log`` is True. Its ``'cost'`` entry is the
        second value returned by ``emd_1d_sorted``.

    Raises
    ------
    AssertionError
        If the second dimension of ``x_a`` or ``x_b`` is not 1.
    """
    x_a = np.asarray(x_a)
    x_b = np.asarray(x_b)

    # Accept flat vectors of positions as well as (n, 1) columns.
    if x_a.ndim == 1:
        x_a = x_a.reshape((-1, 1))
    if x_b.ndim == 1:
        x_b = x_b.reshape((-1, 1))

    assert x_a.shape[1] == x_b.shape[1] == 1, "emd_1d should only be used " + \
        "with monodimensional data"

    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)

    # if empty array given then use uniform distributions
    if len(a) == 0:
        a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
    if len(b) == 0:
        b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]

    perm_a = np.argsort(x_a.reshape((-1, )))
    perm_b = np.argsort(x_b.reshape((-1, )))
    # argsort of a permutation yields its inverse, used below to undo the
    # sorting on the rows and columns of the plan.
    inv_perm_a = np.argsort(perm_a)
    inv_perm_b = np.argsort(perm_b)

    M = dist(x_a[perm_a], x_b[perm_b], metric=metric)

    # The weights must be reordered together with their support points;
    # otherwise the marginals of the un-permuted plan would not match
    # ``a`` and ``b`` whenever the inputs are not already sorted.
    G_sorted, cost = emd_1d_sorted(a[perm_a], b[perm_b], M)
    G = G_sorted[inv_perm_a, :][:, inv_perm_b]

    if log:
        return G, {'cost': cost}
    return G