summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:29:56 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-20 14:29:56 +0200
commitf63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba (patch)
tree96dd2a29842c86a3e3875feba1e8fa8ad3076eb7
parent5a6b226de20624b51c2ff98bc30e5611a7a788c7 (diff)
EMD 1d without doc
-rw-r--r--ot/__init__.py4
-rw-r--r--ot/lp/__init__.py43
-rw-r--r--ot/lp/emd_wrap.pyx35
-rw-r--r--test/test_ot.py26
4 files changed, 101 insertions, 7 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index b74b924..5d5b700 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -22,7 +22,7 @@ from . import smooth
from . import stochastic
# OT functions
-from .lp import emd, emd2
+from .lp import emd, emd2, emd_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
@@ -31,6 +31,6 @@ from .utils import dist, unif, tic, toc, toq
__version__ = "0.5.1"
-__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
+__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 02cbd8c..49ded5b 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -14,12 +14,12 @@ import numpy as np
from .import cvx
# import compiled emd
-from .emd_wrap import emd_c, check_result
+from .emd_wrap import emd_c, check_result, emd_1d_sorted
from ..utils import parmap
from .cvx import barycenter
from ..utils import dist
-__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
+__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted']
def emd(a, b, M, numItermax=100000, log=False):
@@ -94,7 +94,7 @@ def emd(a, b, M, numItermax=100000, log=False):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- # if empty array given then use unifor distributions
+ # if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
@@ -187,7 +187,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- # if empty array given then use unifor distributions
+ # if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
@@ -308,4 +308,37 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
log_dict['displacement_square_norms'] = displacement_square_norms
return X, log_dict
else:
- return X \ No newline at end of file
+ return X
+
+
+def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
+ """Solves the Earth Movers distance problem between 1d measures and returns
+ the OT matrix
+
+ """
+ assert x_a.shape[1] == x_b.shape[1] == 1, "emd_1d should only be used " + \
+ "with monodimensional data"
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+
+ # if empty array given then use uniform distributions
+ if len(a) == 0:
+ a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
+ if len(b) == 0:
+ b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+
+ perm_a = np.argsort(x_a.reshape((-1, )))
+ perm_b = np.argsort(x_b.reshape((-1, )))
+ inv_perm_a = np.argsort(perm_a)
+ inv_perm_b = np.argsort(perm_b)
+
+ M = dist(x_a[perm_a], x_b[perm_b], metric=metric)
+
+ G_sorted, cost = emd_1d_sorted(a, b, M)
+ G = G_sorted[inv_perm_a, :][:, inv_perm_b]
+ if log:
+ log = {}
+ log['cost'] = cost
+ return G, log
+ return G \ No newline at end of file
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 83ee6aa..a3d189d 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -93,3 +93,38 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
return G, cost, alpha, beta, result_code
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
+ np.ndarray[double, ndim=1, mode="c"] v_weights,
+ np.ndarray[double, ndim=2, mode="c"] M):
+ r"""
+ Roro's stuff
+ """
+ cdef double cost = 0.
+ cdef int n = u_weights.shape[0]
+ cdef int m = v_weights.shape[0]
+
+ cdef int i = 0
+ cdef double w_i = u_weights[0]
+ cdef int j = 0
+ cdef double w_j = v_weights[0]
+
+ cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
+ dtype=np.float64)
+ while i < n and j < m:
+ if w_i < w_j or j == m - 1:
+ cost += M[i, j] * w_i
+ G[i, j] = w_i
+ i += 1
+ w_j -= w_i
+ w_i = u_weights[i]
+ else:
+ cost += M[i, j] * w_j
+ G[i, j] = w_j
+ j += 1
+ w_i -= w_j
+ w_j = v_weights[j]
+ return G, cost \ No newline at end of file
diff --git a/test/test_ot.py b/test/test_ot.py
index 7652394..7008002 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -46,6 +46,32 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_emd1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ u = np.random.randn(n, 1)
+ v = np.random.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100