summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:29:56 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:29:56 +0200
commitf63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba (patch)
tree96dd2a29842c86a3e3875feba1e8fa8ad3076eb7 /ot
parent5a6b226de20624b51c2ff98bc30e5611a7a788c7 (diff)
EMD 1d without doc
Diffstat (limited to 'ot')
-rw-r--r--ot/__init__.py4
-rw-r--r--ot/lp/__init__.py43
-rw-r--r--ot/lp/emd_wrap.pyx35
3 files changed, 75 insertions, 7 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index b74b924..5d5b700 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -22,7 +22,7 @@ from . import smooth
from . import stochastic
# OT functions
-from .lp import emd, emd2
+from .lp import emd, emd2, emd_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
@@ -31,6 +31,6 @@ from .utils import dist, unif, tic, toc, toq
__version__ = "0.5.1"
-__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
+__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 02cbd8c..49ded5b 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -14,12 +14,12 @@ import numpy as np
from .import cvx
# import compiled emd
-from .emd_wrap import emd_c, check_result
+from .emd_wrap import emd_c, check_result, emd_1d_sorted
from ..utils import parmap
from .cvx import barycenter
from ..utils import dist
-__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
+__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted']
def emd(a, b, M, numItermax=100000, log=False):
@@ -94,7 +94,7 @@ def emd(a, b, M, numItermax=100000, log=False):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- # if empty array given then use unifor distributions
+ # if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
@@ -187,7 +187,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- # if empty array given then use unifor distributions
+ # if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
@@ -308,4 +308,37 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
log_dict['displacement_square_norms'] = displacement_square_norms
return X, log_dict
else:
- return X \ No newline at end of file
+ return X
+
+
+def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
+ """Solves the Earth Movers distance problem between 1d measures and returns
+ the OT matrix
+
+ """
+ assert x_a.shape[1] == x_b.shape[1] == 1, "emd_1d should only be used " + \
+ "with monodimensional data"
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+
+ # if empty array given then use uniform distributions
+ if len(a) == 0:
+ a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
+ if len(b) == 0:
+ b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+
+ perm_a = np.argsort(x_a.reshape((-1, )))
+ perm_b = np.argsort(x_b.reshape((-1, )))
+ inv_perm_a = np.argsort(perm_a)
+ inv_perm_b = np.argsort(perm_b)
+
+ M = dist(x_a[perm_a], x_b[perm_b], metric=metric)
+
+ G_sorted, cost = emd_1d_sorted(a, b, M)
+ G = G_sorted[inv_perm_a, :][:, inv_perm_b]
+ if log:
+ log = {}
+ log['cost'] = cost
+ return G, log
+ return G \ No newline at end of file
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 83ee6aa..a3d189d 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -93,3 +93,38 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
return G, cost, alpha, beta, result_code
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
+ np.ndarray[double, ndim=1, mode="c"] v_weights,
+ np.ndarray[double, ndim=2, mode="c"] M):
+ r"""
+ Roro's stuff
+ """
+ cdef double cost = 0.
+ cdef int n = u_weights.shape[0]
+ cdef int m = v_weights.shape[0]
+
+ cdef int i = 0
+ cdef double w_i = u_weights[0]
+ cdef int j = 0
+ cdef double w_j = v_weights[0]
+
+ cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
+ dtype=np.float64)
+ while i < n and j < m:
+ if w_i < w_j or j == m - 1:
+ cost += M[i, j] * w_i
+ G[i, j] = w_i
+ i += 1
+ w_j -= w_i
+ w_i = u_weights[i]
+ else:
+ cost += M[i, j] * w_j
+ G[i, j] = w_j
+ j += 1
+ w_i -= w_j
+ w_j = v_weights[j]
+ return G, cost \ No newline at end of file