author    Romain Tavenard <romain.tavenard@univ-rennes2.fr>    2019-06-21 11:21:08 +0200
committer Romain Tavenard <romain.tavenard@univ-rennes2.fr>    2019-06-21 11:21:08 +0200
commit    18502d6861a4977cbade957f2e48eeb8dbb55414 (patch)
tree      947cf67b5c118ba6eafd72e38ccb0977085767ca
parent    cada9a3019997e8efb95d96c86985110f1e937b9 (diff)
Sparse G matrix for EMD1d + standard metrics computed without cdist
-rw-r--r--  ot/__init__.py      4
-rw-r--r--  ot/lp/emd_wrap.pyx  29
-rw-r--r--  test/test_ot.py     23
3 files changed, 41 insertions, 15 deletions
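
This commit makes two changes to the 1D EMD solver in emd_wrap.pyx: the optimal transport plan is returned in a sparse (values, indices) form instead of a dense n x m matrix, and the common metrics ('sqeuclidean', 'cityblock', 'euclidean') are evaluated inline on scalar pairs rather than through a cdist-style call on 1x1 arrays. A minimal sketch of the inline metric shortcut, written in plain Python rather than the Cython of emd_wrap.pyx, and using scipy's cdist as a stand-in for the generic fallback path:

import numpy as np
from scipy.spatial.distance import cdist

def pointwise_cost(u_i, v_j, metric='sqeuclidean'):
    # Fast paths: for scalar samples the usual metrics reduce to simple
    # expressions, so no 1x1 array allocation or cdist call is needed.
    if metric == 'sqeuclidean':
        return (u_i - v_j) ** 2
    elif metric in ('cityblock', 'euclidean'):
        # In 1D the L1 and L2 distances coincide: |u_i - v_j|.
        return abs(u_i - v_j)
    # Generic fallback, mirroring the original cdist-based code path.
    return cdist(np.array([[u_i]]), np.array([[v_j]]), metric=metric)[0, 0]
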
diff --git a/ot/__init__.py b/ot/__init__.py
index 5d5b700..f0e526c 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -22,7 +22,7 @@ from . import smooth
from . import stochastic
# OT functions
-from .lp import emd, emd2, emd_1d
+from .lp import emd, emd2, emd_1d, emd2_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
@@ -32,5 +32,5 @@ from .utils import dist, unif, tic, toc, toq
__version__ = "0.5.1"
__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
- 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
+ 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'emd_1d', 'emd2_1d',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 2966206..ab88d7f 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -101,8 +101,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
@cython.wraparound(False)
def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
np.ndarray[double, ndim=1, mode="c"] v_weights,
- np.ndarray[double, ndim=2, mode="c"] u,
- np.ndarray[double, ndim=2, mode="c"] v,
+ np.ndarray[double, ndim=1, mode="c"] u,
+ np.ndarray[double, ndim=1, mode="c"] v,
str metric='sqeuclidean'):
r"""
Roro's stuff
@@ -118,21 +118,34 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cdef double m_ij = 0.
- cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
+ cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
dtype=np.float64)
+ cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+ dtype=np.int)
+ cdef int cur_idx = 0
while i < n and j < m:
- m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
- metric=metric)[0, 0]
+ if metric == 'sqeuclidean':
+ m_ij = (u[i] - v[j]) ** 2
+ elif metric == 'cityblock' or metric == 'euclidean':
+ m_ij = np.abs(u[i] - v[j])
+ else:
+ m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
+ metric=metric)[0, 0]
if w_i < w_j or j == m - 1:
cost += m_ij * w_i
- G[i, j] = w_i
+ G[cur_idx] = w_i
+ indices[cur_idx, 0] = i
+ indices[cur_idx, 1] = j
i += 1
w_j -= w_i
w_i = u_weights[i]
else:
cost += m_ij * w_j
- G[i, j] = w_j
+ G[cur_idx] = w_j
+ indices[cur_idx, 0] = i
+ indices[cur_idx, 1] = j
j += 1
w_i -= w_j
w_j = v_weights[j]
- return G, cost
\ No newline at end of file
+ cur_idx += 1
+ return G[:cur_idx], indices[:cur_idx], cost
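
The new return value of emd_1d_sorted is a sparse encoding of the coupling: G holds the transported masses and indices their (i, j) positions. Since every loop iteration advances either i or j, at most n + m - 1 entries are written before both distributions are exhausted, hence the preallocated size. When a dense plan is needed it can be rebuilt from this pair; a minimal sketch with a hypothetical helper, assuming G_sparse and indices come from emd_1d_sorted:

import numpy as np
from scipy.sparse import coo_matrix

def dense_plan(G_sparse, indices, n, m):
    # Scatter the sparse masses back into a full (n, m) coupling matrix;
    # indices[:, 0] are source bins, indices[:, 1] are target bins.
    return coo_matrix((G_sparse, (indices[:, 0], indices[:, 1])),
                      shape=(n, m)).toarray()
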
diff --git a/test/test_ot.py b/test/test_ot.py
index 7008002..2a2e0a5 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,6 +7,7 @@
import warnings
import numpy as np
+from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
@@ -37,7 +38,7 @@ def test_emd_emd2():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
@@ -46,12 +47,13 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
-def test_emd1d():
+def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
m = 30
- u = np.random.randn(n, 1)
- v = np.random.randn(m, 1)
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
M = ot.dist(u, v, metric='sqeuclidean')
@@ -59,9 +61,20 @@ def test_emd1d():
wass = log["cost"]
G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False)
# check loss is similar
np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
# check G is similar
np.testing.assert_allclose(G, G_1d)
@@ -86,7 +99,7 @@ def test_emd_empty():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
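
For reference, the pattern exercised by the new test can be used directly: emd_1d returns the plan (and, with log=True, the cost in log["cost"]), while emd2_1d returns the loss alone; with the 'euclidean' metric in 1D the loss matches scipy's wasserstein_distance. A usage sketch along the lines of the test above:

import numpy as np
from scipy.stats import wasserstein_distance
import ot

rng = np.random.RandomState(0)
u = rng.randn(20, 1)
v = rng.randn(30, 1)

# Empty weight lists mean uniform weights on the samples.
G, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
loss = ot.emd2_1d([], [], u, v, metric='euclidean', log=False)

# In 1D with the euclidean metric, the EMD equals scipy's 1-Wasserstein distance.
np.testing.assert_allclose(loss, wasserstein_distance(u.ravel(), v.ravel()))

# The marginals of the plan are the uniform weights.
np.testing.assert_allclose(G.sum(axis=1), np.ones(20) / 20)
np.testing.assert_allclose(G.sum(axis=0), np.ones(30) / 30)
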